You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

quantize_util.cc 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
  17. #include <cmath>
  18. #include <string>
  19. #include <algorithm>
  20. #include <memory>
  21. #include <vector>
  22. #include <set>
  23. #include "src/ops/primitive_c.h"
  24. #include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h"
  25. #include "src/common/utils.h"
  26. #include "abstract/abstract_value.h"
  27. #include "securec/include/securec.h"
  28. using std::string;
  29. using std::vector;
  30. namespace mindspore {
  31. namespace lite {
  32. namespace quant {
  // Primitive types that CanConvOpQuantized() accepts as convolution-like operators.
  const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
  schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Conv2D,
  schema::PrimitiveType_DepthwiseConv2D};
  // Primitive types that CanMulOpQuantized() accepts as matrix-multiply-like operators.
  const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {schema::PrimitiveType_MatMul,
  schema::PrimitiveType_FullConnection};
  // Stores the two thresholds consulted by the Can*OpQuantized() checks:
  //   weightSize                      - a weight with fewer elements than this is not quantized.
  //   convWeightQuantChannelThreshold - a conv weight whose first shape dimension is not greater
  //                                     than this is not quantized.
  QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
  : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
  40. bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
  41. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
  42. if (primitive_c == nullptr) {
  43. MS_LOG(ERROR) << "primitive_c is nullptr";
  44. return false;
  45. }
  46. if (!IsContain(conv_types, (schema::PrimitiveType)primitive_c->Type())) {
  47. return false;
  48. }
  49. if (node->size() < 3) {
  50. return false;
  51. }
  52. auto inputNode = node->input(2);
  53. if (!inputNode->isa<Parameter>()) {
  54. return false;
  55. }
  56. auto paramNode = inputNode->cast<ParameterPtr>();
  57. auto abstract_base = paramNode->abstract();
  58. if (abstract_base == nullptr) {
  59. return false;
  60. }
  61. if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
  62. MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
  63. return false;
  64. }
  65. auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  66. size_t shapeSize = 1;
  67. for (auto dim : weight_shape) {
  68. shapeSize = shapeSize * dim;
  69. }
  70. if (shapeSize < mWeightSize) {
  71. MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
  72. return false;
  73. }
  74. if (weight_shape[0] <= static_cast<int>(mConvWeightQuantChannelThreshold)) {
  75. MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
  76. return false;
  77. }
  78. return true;
  79. }
  80. bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
  81. if (!node->isa<CNode>()) {
  82. return false;
  83. }
  84. auto cnode = std::dynamic_pointer_cast<CNode>(node);
  85. auto type = NodePrimitiveType(cnode);
  86. static const std::vector<schema::PrimitiveType> int8OpList = {
  87. schema::PrimitiveType_Nchw2Nhwc,
  88. schema::PrimitiveType_Nhwc2Nchw,
  89. schema::PrimitiveType_Conv2D,
  90. schema::PrimitiveType_DepthwiseConv2D,
  91. schema::PrimitiveType_Add,
  92. schema::PrimitiveType_Pooling,
  93. schema::PrimitiveType_Concat,
  94. schema::PrimitiveType_Split,
  95. schema::PrimitiveType_TupleGetItem,
  96. schema::PrimitiveType_Reshape,
  97. schema::PrimitiveType_FullConnection,
  98. schema::PrimitiveType_MatMul,
  99. schema::PrimitiveType_Crop,
  100. schema::PrimitiveType_DeDepthwiseConv2D,
  101. schema::PrimitiveType_DeConv2D,
  102. schema::PrimitiveType_Activation,
  103. schema::PrimitiveType_TupleGetItem,
  104. };
  105. bool contain = IsContain(int8OpList, type);
  106. if (!contain) {
  107. MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope()
  108. << " of type: " << schema::EnumNamePrimitiveType(type);
  109. }
  110. return contain;
  111. }
  112. bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
  113. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
  114. if (primitive_c == nullptr) {
  115. MS_LOG(ERROR) << "primitive_c is nullptr";
  116. return false;
  117. }
  118. if (!IsContain(mul_types, (schema::PrimitiveType)primitive_c->Type())) {
  119. return false;
  120. }
  121. if (node->size() < 3) {
  122. MS_LOG(INFO) << "input size less!";
  123. return false;
  124. }
  125. auto inputNode1 = node->input(1);
  126. auto inputNode2 = node->input(2);
  127. if (inputNode1 == nullptr || inputNode2 == nullptr) {
  128. MS_LOG(INFO) << "mul input is nullptr!";
  129. return false;
  130. }
  131. ParameterPtr paramNode = nullptr;
  132. if (inputNode1->isa<Parameter>()) {
  133. paramNode = inputNode1->cast<ParameterPtr>();
  134. } else if (inputNode2->isa<Parameter>()) {
  135. paramNode = inputNode2->cast<ParameterPtr>();
  136. }
  137. if (paramNode == nullptr) {
  138. MS_LOG(INFO) << "invalid paramNode!";
  139. return false;
  140. }
  141. auto abstract_base = paramNode->abstract();
  142. if (abstract_base == nullptr) {
  143. MS_LOG(INFO) << "abstract is nullptr";
  144. return false;
  145. }
  146. if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
  147. MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
  148. return false;
  149. }
  150. auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  151. size_t shapeSize = 1;
  152. for (auto dim : weight_shape) {
  153. shapeSize = shapeSize * dim;
  154. }
  155. if (shapeSize < mWeightSize) {
  156. MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
  157. return false;
  158. }
  159. return true;
  160. }
  161. STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
  162. int quant_min, int num_bits) {
  163. MS_ASSERT(quantParam != nullptr);
  164. if (mMin > 0.0f) {
  165. MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
  166. mMin = 0.0f;
  167. }
  168. if (mMax < 0.0f) {
  169. MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
  170. mMax = 0.0f;
  171. }
  172. if (mMin > mMax) {
  173. MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
  174. return RET_PARAM_INVALID;
  175. }
  176. if (mMin == mMax) {
  177. if (mMin != 0.0f) {
  178. MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
  179. return RET_ERROR;
  180. }
  181. quantParam->inited = true;
  182. quantParam->min = mMin;
  183. quantParam->max = mMax;
  184. quantParam->scale = 0.0f;
  185. quantParam->zeroPoint = 0;
  186. quantParam->narrowRange = narrowRange;
  187. quantParam->numBits = num_bits;
  188. return RET_OK;
  189. }
  190. auto quantMinFloat = static_cast<double>(quant_min);
  191. auto quantMaxFloat = static_cast<double>(quant_max);
  192. double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  193. const double zeroPointFromMin = quantMinFloat - mMin / scale;
  194. int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
  195. // The zero point should always be in the range of quantized value,
  196. // [qmin, qmax].
  197. MS_ASSERT(zeroPoint >= quantMin);
  198. MS_ASSERT(zeroPoint <= quantMax);
  199. quantParam->inited = true;
  200. quantParam->min = mMin;
  201. quantParam->max = mMax;
  202. quantParam->scale = scale;
  203. quantParam->zeroPoint = zeroPoint;
  204. quantParam->narrowRange = narrowRange;
  205. quantParam->numBits = num_bits;
  206. return RET_OK;
  207. }
  208. STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
  209. MS_ASSERT(quantParam != nullptr);
  210. if (mMin > 0.0f) {
  211. MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
  212. mMin = 0.0f;
  213. }
  214. if (mMax < 0.0f) {
  215. MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
  216. mMax = 0.0f;
  217. }
  218. if (mMin > mMax) {
  219. MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
  220. return RET_PARAM_INVALID;
  221. }
  222. if (mMin == mMax) {
  223. if (mMin != 0.0f) {
  224. MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
  225. return RET_ERROR;
  226. }
  227. quantParam->inited = false;
  228. quantParam->min = mMin;
  229. quantParam->max = mMax;
  230. quantParam->scale = 0.0f;
  231. quantParam->zeroPoint = 0;
  232. quantParam->narrowRange = narrowRange;
  233. quantParam->numBits = numBits;
  234. return RET_OK;
  235. }
  236. const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
  237. const int8_t quantMax = std::numeric_limits<int8_t>::max();
  238. auto quantMinFloat = static_cast<double>(quantMin);
  239. auto quantMaxFloat = static_cast<double>(quantMax);
  240. double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  241. const double zeroPointFromMin = quantMinFloat - mMin / scale;
  242. const double zeroPointFromMax = quantMaxFloat - mMax / scale;
  243. const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
  244. const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
  245. const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
  246. int zeroPoint;
  247. if (zpDouble < quantMinFloat) {
  248. zeroPoint = quantMin;
  249. } else if (zpDouble > quantMaxFloat) {
  250. zeroPoint = quantMax;
  251. } else {
  252. zeroPoint = static_cast<int32_t>(std::round(zpDouble));
  253. }
  254. if (std::abs(mMin) == std::abs(mMax)) {
  255. zeroPoint = 0;
  256. }
  257. // The zero point should always be in the range of quantized value,
  258. // [qmin, qmax].
  259. MS_ASSERT(zeroPoint >= quantMin);
  260. MS_ASSERT(zeroPoint <= quantMax);
  261. quantParam->inited = true;
  262. quantParam->min = mMin;
  263. quantParam->max = mMax;
  264. quantParam->scale = scale;
  265. quantParam->zeroPoint = zeroPoint;
  266. quantParam->narrowRange = narrowRange;
  267. quantParam->numBits = numBits;
  268. return RET_OK;
  269. }
  270. STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
  271. auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
  272. vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
  273. vector<uint8_t> qDatas_packed;
  274. if (bitNum < 8 && bitNum > 1) {
  275. BitPack weight_bitpack(bitNum);
  276. weight_bitpack.BitPacking(qDatas, qDatas_packed);
  277. if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) {
  278. MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed";
  279. return RET_ERROR;
  280. }
  281. } else if (bitNum == 8) {
  282. if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) {
  283. MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed";
  284. return RET_ERROR;
  285. }
  286. } else {
  287. MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum;
  288. return RET_ERROR;
  289. }
  290. return RET_OK;
  291. }
  292. static bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp,
  293. size_t *min_idx) {
  294. size_t length = data.size();
  295. if (max_tmp - data.at(index) < delta) {
  296. return false;
  297. }
  298. float range_ratio = (data.at(index) - *min_tmp) / (max_tmp - *min_tmp);
  299. float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
  300. if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
  301. *min_idx = index;
  302. *min_tmp = data.at(index);
  303. }
  304. return true;
  305. }
  306. static bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp, const float &min_tmp,
  307. size_t *max_idx) {
  308. size_t length = data.size();
  309. if (data.at(index) - min_tmp < delta) {
  310. return false;
  311. }
  312. float range_ratio = (*max_tmp - data.at(index)) / (*max_tmp - min_tmp);
  313. float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
  314. if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
  315. *max_idx = index;
  316. *max_tmp = data.at(index);
  317. }
  318. return true;
  319. }
  // Returns the outlier_percent-th percentile of `datas`, which must be sorted ascending.
  //
  // With rank = outlier_percent * size / 100:
  //   - if rank is not a whole number, the element at position ceil(rank) - 1 is returned;
  //   - if rank is a whole number, the mean of the elements at positions rank - 1 and rank is
  //     returned.
  // Positions are limited to the valid index range, so percent values of 0 (or less) yield the
  // first element and 100 (or more) yield the last one instead of indexing out of range.
  // An empty input yields 0.0f.
  // The rank is computed in integer arithmetic so that "whole number" is decided exactly, also for
  // inputs too large for float precision.
  static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
    if (datas.empty()) {
      return 0.0f;
    }
    const long long size = static_cast<long long>(datas.size());
    const long long scaled = static_cast<long long>(outlier_percent) * size;  // rank * 100
    const bool is_whole = (scaled % 100 == 0);
    long long index = scaled / 100;
    if (!is_whole && scaled > 0) {
      ++index;  // integer division truncated; round up to get ceil(rank)
    }
    const long long last = size - 1;
    const auto at_clamped = [&datas, last](long long pos) {
      return datas[static_cast<size_t>(std::min(std::max(pos, 0LL), last))];
    };
    const float lower = at_clamped(index - 1);
    if (!is_whole) {
      return lower;
    }
    return (lower + at_clamped(index)) / 2;
  }
  332. std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas) {
  333. std::sort(max_datas.begin(), max_datas.end());
  334. std::sort(min_datas.begin(), min_datas.end());
  335. float min_val = CalPercentile(min_datas, percent);
  336. float max_val = CalPercentile(max_datas, 100 - percent);
  337. std::reverse(max_datas.begin(), max_datas.end());
  338. MS_ASSERT(min_val < max_val);
  339. MS_ASSERT(min_datas.size() == max_datas.size());
  340. float min_tmp = min_val;
  341. float max_tmp = max_val;
  342. size_t min_idx = 0;
  343. size_t max_idx = 0;
  344. size_t length = min_datas.size();
  345. for (size_t i = 0; i < length; i++) {
  346. if (!SearchLowerBound(min_datas, i, max_tmp, &min_tmp, &min_idx)) {
  347. break;
  348. }
  349. if (!SearchUpperBound(min_datas, i, &max_tmp, min_tmp, &max_idx)) {
  350. break;
  351. }
  352. }
  353. std::pair<float, float> result{min_tmp, max_tmp};
  354. return result;
  355. }
  // Picks k initial cluster centres from the distinct values of data[0 .. elem_count).
  //
  // The distinct values are taken in ascending order and sampled at positions i * size / (k - 1)
  // for i = 0 .. k-1. A fractional position yields the mean of the two neighbouring values, a
  // whole position yields the value itself. Positions are limited to the last valid index: the
  // final sample (and any position rounding past the end) previously read beyond the array.
  //
  // Returns an empty vector when data is null, k is 0, or there are fewer than k distinct values.
  // For k == 1 the single centre is the smallest distinct value.
  static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k) {
    std::vector<float> clusters{};
    if (data == nullptr || k == 0) {
      return clusters;
    }
    // std::set keeps the distinct values in ascending order, so no further sort is needed.
    const std::set<float> set_unique(data, data + elem_count);
    if (set_unique.size() < k) {
      return clusters;
    }
    const std::vector<float> data_unique(set_unique.begin(), set_unique.end());
    const size_t last = data_unique.size() - 1;
    // k == 1 would divide by zero; a step of 0 keeps the single sample at position 0.
    const float step = (k > 1) ? static_cast<float>(data_unique.size()) / (k - 1) : 0.0f;
    clusters.reserve(k);
    for (size_t i = 0; i < k; i++) {
      const float position = i * step;
      const size_t index = std::min(static_cast<size_t>(std::floor(position)), last);
      if (position - index > 0) {
        const size_t next = std::min(index + 1, last);
        clusters.emplace_back((data_unique[index] + data_unique[next]) / 2);
      } else {
        clusters.emplace_back(data_unique[index]);
      }
    }
    return clusters;
  }
  380. std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
  381. std::vector<float> clusters = InitClusters(data, elem_count, k);
  382. std::vector<int8_t> clusters_index{};
  383. double error{0};
  384. if (clusters.size() < k) {
  385. MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
  386. return clusters_index;
  387. }
  388. for (size_t epoch = 0; epoch < epochs; epoch++) {
  389. double error_cur{0};
  390. clusters_index.clear();
  391. std::vector<std::vector<float>> clusters_data(clusters.size());
  392. for (size_t i = 0; i < elem_count; i++) {
  393. size_t index = 0;
  394. float min_distance = pow(data[i] - clusters[0], 2);
  395. for (size_t j = 1; j < clusters.size(); j++) {
  396. if (pow(data[i] - clusters[j], 2) < min_distance) {
  397. min_distance = pow(data[i] - clusters[j], 2);
  398. index = j;
  399. }
  400. }
  401. clusters_index.emplace_back(index + INT8_MIN);
  402. clusters_data[index].emplace_back(data[i]);
  403. }
  404. for (size_t j = 0; j < clusters.size(); j++) {
  405. if (clusters_data[j].size() > 0) {
  406. clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
  407. }
  408. }
  409. // compare error
  410. for (size_t j = 0; j < elem_count; j++) {
  411. error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
  412. }
  413. error_cur = pow(error_cur / elem_count, 0.5);
  414. if (std::abs((error_cur - error) / error_cur) < 1e-6) {
  415. break;
  416. }
  417. error = error_cur;
  418. }
  419. // update data
  420. quantParam->clusters = clusters;
  421. return clusters_index;
  422. }
  423. schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
  424. if (cnode == nullptr) {
  425. MS_LOG(ERROR) << "cnode is null";
  426. return schema::PrimitiveType_NONE;
  427. }
  428. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  429. if (primitive_c == nullptr) {
  430. MS_LOG(ERROR) << "primitive_c is null";
  431. return schema::PrimitiveType_NONE;
  432. }
  433. return (schema::PrimitiveType)primitive_c->Type();
  434. }
  435. } // namespace quant
  436. } // namespace lite
  437. } // namespace mindspore