You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

weight_quantizer.cc 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/converter/quantizer/weight_quantizer.h"
  17. #include <list>
  18. #include <string>
  19. #include <vector>
  20. #include <unordered_map>
  21. #include "src/common/common.h"
  22. using std::string;
  23. using std::vector;
  24. namespace mindspore::lite::quant {
  25. WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
  26. quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
  27. config_param_ = config;
  28. }
  29. WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
  30. this->config_file_ = config.configFile;
  31. auto quantSize = config.quantWeightSize;
  32. this->bit_num_ = config.bitNum;
  33. auto convQuantWeightChannelThreshold = config.quantWeightChannel;
  34. quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
  35. quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
  36. quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
  37. // parse type_id_
  38. if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
  39. type_id_ = kNumberTypeInt8;
  40. } else if (this->bit_num_ <= 16) {
  41. type_id_ = kNumberTypeInt16;
  42. } else {
  43. MS_LOG(ERROR) << "invalid input bits";
  44. }
  45. }
  46. WeightQuantizer::~WeightQuantizer() {
  47. for (const auto &fp32_output_tensor : fp32_output_tensors_) {
  48. for (const auto &kv : fp32_output_tensor) {
  49. delete kv.second;
  50. }
  51. }
  52. }
  53. STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
  54. std::shared_ptr<PrimitiveC> primitive_c) {
  55. // set dtype
  56. param_value->set_tensor_type(type_id_);
  57. auto abstract_base = param_node->abstract();
  58. if (abstract_base == nullptr) {
  59. MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
  60. return RET_ERROR;
  61. }
  62. if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
  63. MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
  64. return RET_ERROR;
  65. }
  66. auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  67. abstract_tensor->element()->set_type(TypeIdToType(type_id_));
  68. primitive_c->set_quant_type(schema::QuantType_WeightQuant);
  69. return RET_OK;
  70. }
  71. STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) {
  72. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  73. if (primitive_c == nullptr) {
  74. MS_LOG(ERROR) << "primitive_c is nullptr";
  75. return RET_ERROR;
  76. }
  77. auto input_node = cnode->input(2);
  78. if (!input_node->isa<Parameter>()) {
  79. return RET_ERROR;
  80. }
  81. ParameterPtr param_node;
  82. ParamValueLitePtr param_value;
  83. GetLiteParameter(input_node, &param_node, &param_value);
  84. if (param_node == nullptr || param_value == nullptr) {
  85. MS_LOG(ERROR) << "GetLiteParameter error";
  86. return RET_ERROR;
  87. }
  88. if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
  89. MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
  90. return RET_ERROR;
  91. }
  92. auto status = RET_ERROR;
  93. if (type_id_ == kNumberTypeInt8) {
  94. status =
  95. QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
  96. } else if (type_id_ == kNumberTypeInt16) {
  97. status =
  98. QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
  99. }
  100. if (status == RET_CONTINUE) {
  101. return RET_OK;
  102. } else if (status != RET_OK) {
  103. MS_LOG(ERROR) << "QuantFilter failed : " << status;
  104. return status;
  105. }
  106. status = SetAbstract(param_value, param_node, primitive_c);
  107. if (status != RET_OK) {
  108. MS_LOG(ERROR) << "SetAbstract failed : " << status;
  109. return RET_ERROR;
  110. }
  111. return RET_OK;
  112. }
// Quantizes the weight operand of a Mul-like node. Unlike DoConvQuantize, the
// weight's position is not fixed: the inputs are scanned for the first
// Parameter holding non-empty float32 data, and that input index is passed to
// QuantFilter (as index - 1) so the quant params land in the right slot.
STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
  auto already_quant = false;
  ParamValueLitePtr param_value = nullptr;
  ParameterPtr param_node = nullptr;
  int index = 0;
  // Scan inputs 1..n-1. Each candidate Parameter is classified:
  //  - empty / null data          -> reset param_value, keep scanning;
  //  - already int8/int16         -> node was quantized before, stop, no-op;
  //  - non-float32                -> reset param_value, keep scanning;
  //  - float32                    -> remember its input index and stop.
  for (size_t i = 1; i < cnode->size(); i++) {
    auto inputNode = cnode->input(i);
    if (inputNode->isa<Parameter>()) {
      param_node = inputNode->cast<ParameterPtr>();
      if ((param_node != nullptr) && param_node->has_default()) {
        param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
        if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) {
          param_value = nullptr;
          continue;
        } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
                   param_value->tensor_type() == mindspore::kNumberTypeInt16) {
          MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been "
                       << " quantized";
          already_quant = true;
          break;
        } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
          param_value = nullptr;
          continue;
        } else {
          index = i;
          break;
        }
      }
    }
  }
  if (already_quant) {
    return RET_OK;
  }
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "No valid input param node !";
    return RET_ERROR;
  }
  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return RET_ERROR;
  }
  auto status = RET_ERROR;
  // index - 1 converts the cnode input index (1-based: input 0 is the
  // primitive) into the operand index QuantFilter expects.
  if (type_id_ == kNumberTypeInt8) {
    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                 true, index - 1);
  } else if (type_id_ == kNumberTypeInt16) {
    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                  true, index - 1);
  }
  if (status == RET_CONTINUE) {
    // RET_CONTINUE: the filter chose to skip this weight; not an error.
    return RET_OK;
  } else if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed : " << status;
    return status;
  }
  status = SetAbstract(param_value, param_node, primitive_c);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << status;
    return RET_ERROR;
  }
  return RET_OK;
}
  176. STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
  177. MS_ASSERT(cnode != nullptr);
  178. auto op_name = cnode->fullname_with_scope();
  179. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  180. MS_ASSERT(primitive_c != nullptr);
  181. if (cnode->inputs().size() < 4) {
  182. MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
  183. return RET_ERROR;
  184. }
  185. auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
  186. if (status != RET_OK) {
  187. MS_LOG(ERROR) << "Process lstm weight i failed.";
  188. return RET_ERROR;
  189. }
  190. status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
  191. if (status != RET_OK) {
  192. MS_LOG(ERROR) << "Process lstm weight h failed.";
  193. return RET_ERROR;
  194. }
  195. if (cnode->inputs().size() > 4) {
  196. status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
  197. if (status != RET_OK) {
  198. MS_LOG(ERROR) << "Process lstm bias failed.";
  199. return RET_ERROR;
  200. }
  201. }
  202. return status;
  203. }
  204. STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
  205. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  206. MS_ASSERT(primitive_c != nullptr);
  207. auto first_input = cnode->input(1);
  208. ParameterPtr param_node;
  209. ParamValueLitePtr param_value;
  210. GetLiteParameter(first_input, &param_node, &param_value);
  211. if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
  212. MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
  213. return RET_OK;
  214. }
  215. if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
  216. MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < "
  217. << quant_strategy_->mWeightSize;
  218. return RET_OK;
  219. }
  220. auto status = RET_ERROR;
  221. if (type_id_ == kNumberTypeInt8) {
  222. status =
  223. QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
  224. } else if (type_id_ == kNumberTypeInt16) {
  225. status =
  226. QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
  227. }
  228. if (status == RET_CONTINUE) {
  229. return RET_OK;
  230. } else if (status != RET_OK) {
  231. MS_LOG(ERROR) << "QuantFilter failed : " << status;
  232. return status;
  233. }
  234. status = SetAbstract(param_value, param_node, primitive_c);
  235. if (status != RET_OK) {
  236. MS_LOG(ERROR) << "SetAbstract failed : " << status;
  237. return RET_ERROR;
  238. }
  239. return RET_OK;
  240. }
  241. STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
  242. const int &index) {
  243. auto op_name = cnode->fullname_with_scope();
  244. auto weight_i = cnode->input(index);
  245. ParameterPtr param_node;
  246. ParamValueLitePtr param_value;
  247. GetLiteParameter(weight_i, &param_node, &param_value);
  248. if (param_node == nullptr || param_value == nullptr) {
  249. MS_LOG(ERROR) << "GetLiteParameter error";
  250. return RET_ERROR;
  251. }
  252. if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
  253. MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
  254. return RET_OK;
  255. }
  256. if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
  257. MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
  258. << quant_strategy_->mWeightSize;
  259. return RET_OK;
  260. }
  261. auto status = RET_ERROR;
  262. if (type_id_ == kNumberTypeInt8) {
  263. status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
  264. false, index - 1);
  265. } else if (type_id_ == kNumberTypeInt16) {
  266. status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
  267. false, index - 1);
  268. }
  269. if (status == RET_CONTINUE) {
  270. return RET_OK;
  271. } else if (status != RET_OK) {
  272. MS_LOG(ERROR) << "QuantFilter failed : " << status;
  273. return status;
  274. }
  275. status = SetAbstract(param_value, param_node, primitive_c);
  276. if (status != RET_OK) {
  277. MS_LOG(ERROR) << "SetAbstract failed : " << status;
  278. return RET_ERROR;
  279. }
  280. return RET_OK;
  281. }
  282. constexpr float relative_tolerance = 1e-5;
  283. constexpr float abs_tolerance = 1e-4;
  284. template <typename T>
  285. float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor,
  286. const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) {
  287. auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); };
  288. float total_mean_error = 0.0f;
  289. int tensor_cnt = expected_tensor.size();
  290. if (tensor_cnt <= 0) {
  291. MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt;
  292. return RET_ERROR;
  293. }
  294. for (const auto &exp_tensor_pair : expected_tensor) {
  295. float mean_error = 0.0f;
  296. int error_cnt = 0;
  297. auto exp_tensor_name = exp_tensor_pair.first;
  298. auto exp_tensor = exp_tensor_pair.second;
  299. auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name);
  300. if (cmp_tensor_find_iter == compare_tensor.end()) {
  301. MS_LOG(ERROR) << "can not find: " << exp_tensor_name;
  302. return RET_ERROR;
  303. }
  304. auto cmp_tensor = cmp_tensor_find_iter->second;
  305. auto exp_tensor_shape = exp_tensor->shape();
  306. auto cmp_tensor_shape = cmp_tensor->shape();
  307. if (exp_tensor_shape != cmp_tensor_shape) {
  308. MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum()
  309. << " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum();
  310. return RET_ERROR;
  311. }
  312. auto exp_data = static_cast<T *>(exp_tensor->MutableData());
  313. auto cmp_data = static_cast<T *>(cmp_tensor->MutableData());
  314. auto elem_cnt = exp_tensor->ElementsNum();
  315. for (int i = 0; i < elem_cnt; i++) {
  316. if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) {
  317. MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i;
  318. return RET_ERROR;
  319. }
  320. auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]);
  321. auto abs_error = std::fabs(exp_data[i] - cmp_data[i]);
  322. if (abs_error > tolerance) {
  323. if (fabs(exp_data[i] == 0)) {
  324. if (abs_error > 1e-5) {
  325. mean_error += abs_error;
  326. error_cnt++;
  327. } else {
  328. // it is ok, very close to 0
  329. continue;
  330. }
  331. } else {
  332. mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN);
  333. error_cnt++;
  334. }
  335. } else {
  336. // it is ok, no error
  337. continue;
  338. }
  339. } // end one tensor data loop
  340. total_mean_error += mean_error / elem_cnt;
  341. } // end tensor loop
  342. return total_mean_error / tensor_cnt;
  343. }
  344. STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
  345. auto image_cnt = images_.at(0).size();
  346. if (!config_param_.input_shapes.empty()) {
  347. if (config_param_.input_shapes.size() != image_cnt) {
  348. MS_LOG(ERROR) << "input_shapes size: " << config_param_.input_shapes.size() << " image_cnt: " << image_cnt;
  349. return RET_ERROR;
  350. }
  351. }
  352. // 0.1 Create Fp32 Session
  353. flags.quantType = schema::QuantType_QUANT_NONE;
  354. auto fp32_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
  355. auto fp32_session = fp32_sm.session;
  356. auto fp32_model = fp32_sm.model;
  357. if (fp32_session == nullptr || fp32_model == nullptr) {
  358. MS_LOG(ERROR) << "CreateSessoin fail";
  359. delete fp32_model;
  360. return RET_ERROR;
  361. }
  362. auto fp32_inputs = fp32_session->GetInputs();
  363. fp32_output_tensors_.resize(image_cnt);
  364. // 0.3 save fp32 output
  365. for (size_t i = 0; i < image_cnt; i++) {
  366. if (!config_param_.input_shapes.empty()) {
  367. auto status = fp32_session->Resize(fp32_inputs, {config_param_.input_shapes[i]});
  368. if (status != RET_OK) {
  369. MS_LOG(ERROR) << "session Resize fail";
  370. delete fp32_sm.session;
  371. delete fp32_sm.model;
  372. return RET_ERROR;
  373. }
  374. }
  375. for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) {
  376. auto status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]);
  377. if (status != RET_OK) {
  378. MS_LOG(ERROR) << "generate input data from images failed!";
  379. delete fp32_sm.session;
  380. delete fp32_sm.model;
  381. return RET_ERROR;
  382. }
  383. }
  384. auto status = fp32_session->RunGraph();
  385. if (status != RET_OK) {
  386. MS_LOG(ERROR) << "RunGraph fail";
  387. delete fp32_sm.session;
  388. delete fp32_sm.model;
  389. return RET_ERROR;
  390. }
  391. auto fp32_outputs = fp32_session->GetOutputs();
  392. for (const auto &kv : fp32_outputs) {
  393. auto *tensor = kv.second;
  394. auto *lite_tensor = reinterpret_cast<lite::Tensor *>(tensor);
  395. if (lite_tensor == nullptr) {
  396. MS_LOG(ERROR) << "not lite tensor";
  397. delete fp32_sm.session;
  398. delete fp32_sm.model;
  399. return RET_ERROR;
  400. }
  401. auto *new_tensor = Tensor::CopyTensor(*lite_tensor, true);
  402. fp32_output_tensors_[i][kv.first] = new_tensor;
  403. }
  404. }
  405. delete fp32_sm.session;
  406. delete fp32_sm.model;
  407. return RET_OK;
  408. }
  409. STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
  410. auto cnodes = func_graph->GetOrderedCnodes();
  411. int status = RET_OK;
  412. for (auto &cnode : cnodes) {
  413. auto op_type = NodePrimitiveType(cnode);
  414. if (op_type == schema::PrimitiveType_Lstm) {
  415. status = DoLstmQuantize(cnode);
  416. if (status != RET_OK) {
  417. MS_LOG(ERROR) << "DoLstmQuantize error";
  418. return RET_ERROR;
  419. }
  420. } else if (op_type == schema::PrimitiveType_Gather) {
  421. status = DoGatherQuantize(cnode);
  422. if (status != RET_OK) {
  423. MS_LOG(ERROR) << "DoGatherQuantize error";
  424. return RET_ERROR;
  425. }
  426. }
  427. }
  428. return status;
  429. }
  430. STATUS WeightQuantizer::CheckImageCnt() {
  431. auto image_cnt = images_.at(0).size();
  432. if (!config_param_.input_shapes.empty()) {
  433. if (config_param_.input_shapes.size() != image_cnt) {
  434. MS_LOG(ERROR) << "input_shapes size: " << config_param_.input_shapes.size() << " image_cnt: " << image_cnt;
  435. return RET_ERROR;
  436. }
  437. }
  438. return RET_OK;
  439. }
  440. STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
  441. ParameterPtr *param_node, ParamValueLitePtr *param_value) {
  442. if (!input_node->isa<Parameter>()) {
  443. MS_LOG(WARNING) << op_name << " the second input is not parameter";
  444. return RET_CONTINUE;
  445. }
  446. *param_node = input_node->cast<ParameterPtr>();
  447. if (!(*param_node)->has_default()) {
  448. MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
  449. return RET_CONTINUE;
  450. }
  451. *param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
  452. if (*param_value == nullptr) {
  453. MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
  454. return RET_CONTINUE;
  455. }
  456. if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) {
  457. MS_LOG(WARNING) << op_name << " the second input type is not float";
  458. return RET_CONTINUE;
  459. }
  460. return RET_OK;
  461. }
  462. STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
  463. const ParamValueLitePtr &param_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
  464. int status;
  465. type_id_ = TypeId::kNumberTypeInt8;
  466. int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
  467. int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
  468. if (type_id_ == TypeId::kNumberTypeInt8) {
  469. status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
  470. bit_num_t, true);
  471. } else if (type_id_ == TypeId::kNumberTypeInt16) {
  472. status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
  473. bit_num_t, true);
  474. } else {
  475. MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
  476. return RET_ERROR;
  477. }
  478. if (status == RET_CONTINUE) {
  479. return RET_OK;
  480. } else if (status != RET_OK) {
  481. MS_LOG(ERROR) << "quant filter failed.";
  482. return RET_ERROR;
  483. }
  484. status = SetAbstract(param_value, param_node, primitive_c);
  485. if (status != RET_OK) {
  486. MS_LOG(ERROR) << "SetAbstract failed : " << status;
  487. return RET_ERROR;
  488. }
  489. return status;
  490. }
  491. STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
  492. auto cnodes = func_graph->GetOrderedCnodes();
  493. auto image_cnt = images_.at(0).size();
  494. int status = RET_OK;
  495. for (auto iter = cnodes.end(); iter != cnodes.begin();) {
  496. auto cnode = *(--iter);
  497. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  498. if (primitive_c == nullptr) {
  499. MS_LOG(ERROR) << "primitive_c is null.";
  500. return RET_ERROR;
  501. }
  502. auto op_name = cnode->fullname_with_scope();
  503. MS_LOG(DEBUG) << "process node: " << op_name
  504. << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
  505. if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
  506. auto input_node = cnode->input(2);
  507. ParameterPtr param_node;
  508. ParamValueLitePtr param_value;
  509. status = GetParamNodeAndValue(input_node, op_name, &param_node, &param_value);
  510. if (status == RET_CONTINUE) {
  511. continue;
  512. }
  513. // copy origin data in case to recover
  514. auto *raw_data = static_cast<float *>(param_value->tensor_addr());
  515. auto elem_count = param_value->tensor_shape_size();
  516. std::unique_ptr<float[]> origin_data(new (std::nothrow) float[elem_count]);
  517. auto ret = memcpy_s(origin_data.get(), sizeof(float) * elem_count, raw_data, param_value->tensor_size());
  518. if (ret != EOK) {
  519. MS_LOG(ERROR) << "memcpy fail: "
  520. << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size();
  521. return RET_ERROR;
  522. }
  523. // 1. try quant
  524. for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
  525. status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
  526. if (status != RET_OK) {
  527. MS_LOG(ERROR) << "TryQuant failed.";
  528. return RET_ERROR;
  529. }
  530. // 2. evaluate the quant
  531. // 2.1 create quant session, get input, output tensor
  532. flags.quantType = schema::QuantType_WeightQuant;
  533. auto quant_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
  534. auto quant_session = std::unique_ptr<session::LiteSession>(quant_sm.session);
  535. if (quant_session == nullptr) {
  536. MS_LOG(ERROR) << "create session error: " << status;
  537. delete quant_sm.model;
  538. return RET_ERROR;
  539. }
  540. auto quant_inputs = quant_session->GetInputs();
  541. auto mean_error = 0.0f;
  542. for (size_t i = 0; i < image_cnt; i++) {
  543. if (!config_param_.input_shapes.empty()) {
  544. status = quant_session->Resize(quant_inputs, {config_param_.input_shapes[i]});
  545. if (status != RET_OK) {
  546. MS_LOG(ERROR) << "session Resize fail";
  547. delete quant_sm.model;
  548. return RET_ERROR;
  549. }
  550. }
  551. // set multi-input data
  552. for (size_t input_index = 0; input_index < quant_inputs.size(); input_index++) {
  553. status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]);
  554. if (status != RET_OK) {
  555. MS_LOG(ERROR) << "generate input data from images failed!";
  556. delete quant_sm.model;
  557. return RET_ERROR;
  558. }
  559. }
  560. status = quant_session->RunGraph();
  561. if (status != RET_OK) {
  562. MS_LOG(ERROR) << "quant session run error";
  563. delete quant_sm.model;
  564. return RET_ERROR;
  565. }
  566. // 3. compare between quant and fp32
  567. auto quant_outputs = quant_session->GetOutputs();
  568. mean_error += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
  569. } // end_for: calib data loop
  570. delete quant_sm.model;
  571. mean_error = mean_error / image_cnt;
  572. if (mean_error <= config_param_.mean_error_threshold) {
  573. MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error;
  574. opname_bit_[op_name] = bit_num_t;
  575. break;
  576. } else if (bit_num_t != 8) {
  577. MS_LOG(DEBUG) << "op: " << op_name << " intermediate bit: " << bit_num_t << " mean_error: " << mean_error
  578. << " [recover]";
  579. // recover
  580. status = UpdateTensorDataAndSize(param_value, origin_data.get(), sizeof(float) * elem_count);
  581. if (status != RET_OK) {
  582. MS_LOG(ERROR) << "UpdateTensorDataAndSize fail";
  583. return RET_ERROR;
  584. }
  585. } else {
  586. MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error;
  587. opname_bit_[op_name] = bit_num_t;
  588. }
  589. } // end bit loop
  590. } // if: conv and matmul
  591. } // end loop: all cnode
  592. return status;
  593. }
  594. STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
  595. // 0.2 Parse input calib files
  596. auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
  597. if (status != RET_OK) {
  598. MS_LOG(ERROR) << "CollectCalibInputs failed.";
  599. return RET_ERROR;
  600. }
  601. status = RunFp32Graph(func_graph);
  602. if (status != RET_OK) {
  603. MS_LOG(ERROR) << "RunFp32Graph failed.";
  604. return RET_ERROR;
  605. }
  606. status = DoMixedQuantize(func_graph);
  607. if (status != RET_OK) {
  608. MS_LOG(ERROR) << "DoMixedQuantize failed.";
  609. return RET_ERROR;
  610. }
  611. status = CheckImageCnt();
  612. if (status != RET_OK) {
  613. MS_LOG(ERROR) << "CheckImageCnt failed.";
  614. return RET_ERROR;
  615. }
  616. status = DoQuantSearch(func_graph);
  617. if (status != RET_OK) {
  618. MS_LOG(ERROR) << "DoQuantSearch failed.";
  619. return RET_ERROR;
  620. }
  621. for (const auto &kv : opname_bit_) {
  622. MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
  623. }
  624. return RET_OK;
  625. }
  626. STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
  627. MS_ASSERT(func_graph != nullptr);
  628. for (auto &cnode : func_graph->GetOrderedCnodes()) {
  629. auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  630. if (primitive_c == nullptr) {
  631. MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive_c is nullptr";
  632. continue;
  633. }
  634. auto op_name = cnode->fullname_with_scope();
  635. auto op_type = (schema::PrimitiveType)primitive_c->Type();
  636. if (quant_strategy_->CanConvOpQuantized(cnode)) {
  637. auto status = DoConvQuantize(cnode);
  638. if (status != RET_OK) {
  639. MS_LOG(ERROR) << "DoConvQuantize error";
  640. return RET_ERROR;
  641. }
  642. } else if (quant_strategy_->CanMulOpQuantized(cnode)) {
  643. auto status = DoMulQuantize(cnode);
  644. if (status != RET_OK) {
  645. MS_LOG(ERROR) << "DoMulQuantize error";
  646. return RET_ERROR;
  647. }
  648. } else if (op_type == schema::PrimitiveType_Lstm) {
  649. auto status = DoLstmQuantize(cnode);
  650. if (status != RET_OK) {
  651. MS_LOG(ERROR) << "DoLstmQuantize error";
  652. return RET_ERROR;
  653. }
  654. } else if (op_type == schema::PrimitiveType_Gather) {
  655. auto status = DoGatherQuantize(cnode);
  656. if (status != RET_OK) {
  657. MS_LOG(ERROR) << "DoGatherQuantize error";
  658. return RET_ERROR;
  659. }
  660. } else {
  661. MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant";
  662. }
  663. }
  664. return RET_OK;
  665. }
  666. STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  667. MS_ASSERT(func_graph != nullptr);
  668. if (!config_file_.empty()) {
  669. auto ret = ParseConfigFile(config_file_, &config_param_);
  670. if (ret != RET_OK) {
  671. MS_LOG(ERROR) << "ReadConfig error.";
  672. return RET_ERROR;
  673. }
  674. }
  675. if (config_param_.mixed) {
  676. bit_num_ = 8;
  677. quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
  678. quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
  679. type_id_ = kNumberTypeInt8;
  680. MS_LOG(INFO) << "Do mixed bit quantization";
  681. return DoMixedQuant(func_graph);
  682. }
  683. return DoFixedQuant(func_graph);
  684. }
  685. } // namespace mindspore::lite::quant