You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

aware_quantizer.cc 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "tools/converter/quantizer/aware_quantizer.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h"
#include "src/common/utils.h"
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "utils/log_adapter.h"
using std::string;
using std::vector;
  32. namespace mindspore::lite::quant {
// Ops whose output stays in the same numeric domain as their input: after
// quantization, DoQuantize copies the input tensor's data type onto the
// output tensor for every op listed here.
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
    {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
     schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
     schema::PrimitiveType_DetectionPostProcess}};
  37. STATUS InputArray::InitQuantParam() {
  38. this->quantParam = std::make_unique<schema::QuantParamT>();
  39. auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
  40. if (status != RET_OK) {
  41. return status;
  42. }
  43. return RET_OK;
  44. }
  45. STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensorIdx) {
  46. MS_ASSERT(graph != nullptr);
  47. auto &tensor = graph->allTensors.at(inputTensorIdx);
  48. MS_ASSERT(tensor != nullptr);
  49. if (!tensor->quantParams.empty()) {
  50. auto param = GetTensorQuantParam(tensor);
  51. if (param != nullptr && param->inited) {
  52. MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
  53. return RET_OK;
  54. }
  55. tensor->quantParams.clear();
  56. }
  57. std::unique_ptr<schema::QuantParamT> tmpQuantParam(new QuantParamT());
  58. tmpQuantParam->inited = this->quantParam->inited;
  59. tmpQuantParam->scale = this->quantParam->scale;
  60. tmpQuantParam->zeroPoint = this->quantParam->zeroPoint;
  61. tmpQuantParam->min = this->quantParam->min;
  62. tmpQuantParam->max = this->quantParam->max;
  63. tensor->quantParams.push_back(std::move(tmpQuantParam));
  64. return RET_OK;
  65. }
  66. AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const string &stdValues,
  67. const string &meanValues)
  68. : FbQuantizer(graph) {
  69. MS_ASSERT(graph != nullptr);
  70. string::size_type sz;
  71. const float stdValue = std::stof(stdValues, &sz);
  72. sz = 0;
  73. const float mean = std::stof(meanValues, &sz);
  74. std::unique_ptr<InputArray> inArr = nullptr;
  75. if (inferType == kNumberTypeFloat) {
  76. inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
  77. } else {
  78. inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
  79. }
  80. mInputArray = inArr.get();
  81. mInputArray->InitQuantParam();
  82. }
// Intentional no-op: FakeQuant nodes are expected to be stripped elsewhere
// (GenerateQuantParam asserts none remain); kept to satisfy the FbQuantizer
// interface.
STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
  84. STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
  85. MS_ASSERT(subGraph != nullptr);
  86. for (const auto &tensor : subGraph->allTensors) {
  87. if (!tensor->quantParams.empty()) {
  88. continue;
  89. }
  90. std::unique_ptr<schema::QuantParamT> defaultQuantParam(new QuantParamT());
  91. tensor->quantParams.emplace_back(std::move(defaultQuantParam));
  92. }
  93. return RET_OK;
  94. }
// Intentional no-op stub: convolution attributes need no adjustment in this
// quantizer; kept to satisfy the FbQuantizer interface.
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { return RET_OK; }
  96. STATUS AwareQuantizer::GenerateQuantParam() {
  97. MS_ASSERT(graph->inputIndex.size() == 1);
  98. // set graphInputNode input
  99. for (auto graphInputIndex : graph->inputIndex) {
  100. auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
  101. if (status != RET_OK) {
  102. MS_LOG(ERROR) << "SetInputArrayQP failed";
  103. return status;
  104. }
  105. }
  106. auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
  107. for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
  108. auto &node = *iter;
  109. MS_ASSERT(node != nullptr);
  110. if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
  111. GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
  112. MS_ASSERT(false);
  113. }
  114. auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
  115. if (quantParamCalcer == nullptr) {
  116. MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
  117. << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
  118. node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
  119. } else {
  120. auto status = quantParamCalcer->Calc(graph, *node);
  121. if (status != RET_OK) {
  122. MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
  123. node->quantType = schema::QuantType_QUANT_NONE;
  124. } else {
  125. node->quantType = schema::QuantType_AwareTraining;
  126. }
  127. }
  128. }
  129. return RET_OK;
  130. }
// Quantizes the const inputs (conv/FC weight and bias, DetectionPostProcess
// const tensor, Add const operand) of every node previously marked
// QuantType_AwareTraining, then propagates the input tensor's data type to the
// output tensor for ops listed in propagatedOps.
STATUS AwareQuantizer::DoQuantize() {
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    // Only ops in the uint8 op list are candidates for quantization.
    if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
      continue;
    }
    // Skip nodes GenerateQuantParam/DetermineNodeQuantType left unquantized.
    if (node->quantType != schema::QuantType_AwareTraining) {
      continue;
    }
    STATUS status;
    if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
        GetCNodeTType(*node) == schema::PrimitiveType_FullConnection ||
        GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
      auto inputIndexes = node->inputIndex;
      if (inputIndexes.size() < 2) {
        MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
        return RET_ERROR;
      }
      // quant weight — only when the weight already carries an inited quant param
      auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1));
      if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
        status = QuantConvWeight(graph, node.get());
        if (status != RET_OK) {
          MS_LOG(ERROR) << "QuantConvWeight failed!";
          return RET_ERROR;
        }
      }
      // quant bias — present only when the node has a third input
      if (inputIndexes.size() == 3) {
        auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2));
        if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
          status = QuantConvBias(graph, node.get());
          if (status != RET_OK) {
            MS_LOG(ERROR) << "QuantConvBias failed!";
            return RET_ERROR;
          }
        }
      }
    } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
      status = QuantDetectionPostProcessConstTensor(graph, node.get());
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
        return RET_ERROR;
      }
    } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
      status = QuantAddConstTensor(graph, node.get());
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantAddConstTensor failed!";
        return RET_ERROR;
      }
    }
    // Shape/layout-style ops keep the (possibly now-quantized) input dtype on
    // their output tensor.
    const auto nodeType = GetCNodeTType(*node);
    auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType);
    if (find != propagatedOps.end()) {
      auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get();
      auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get();
      MS_ASSERT(inputTensor != nullptr);
      MS_ASSERT(outputTensor != nullptr);
      outputTensor->dataType = inputTensor->dataType;
    }
  }
  return RET_OK;
}
  195. STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
  196. MS_ASSERT(graph != nullptr);
  197. MS_ASSERT(node != nullptr);
  198. for (size_t i = 0; i < node->inputIndex.size(); i++) {
  199. auto inTensorIdx = node->inputIndex.at(i);
  200. MS_ASSERT(graph->allTensors.size() > inTensorIdx);
  201. auto &inTensor = graph->allTensors.at(inTensorIdx);
  202. MS_ASSERT(inTensor != nullptr);
  203. if (inTensor->refCount == 999) {
  204. switch (inTensor->dataType) {
  205. case TypeId::kNumberTypeFloat: {
  206. auto quantParam = GetTensorQuantParam(inTensor);
  207. MS_ASSERT(quantParam != nullptr);
  208. MS_ASSERT(quantParam->inited);
  209. auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
  210. vector<uint8_t> qDatas(constTensorShapeSize);
  211. void *inData = inTensor->data.data();
  212. auto *castedInData = static_cast<float *>(inData);
  213. for (size_t j = 0; j < constTensorShapeSize; j++) {
  214. qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
  215. }
  216. inTensor->data = std::move(qDatas);
  217. inTensor->dataType = kNumberTypeUInt8;
  218. } break;
  219. case kNumberTypeUInt8:
  220. break;
  221. default:
  222. MS_LOG(ERROR) << "Unsupported dataType: " << inTensor->dataType;
  223. return RET_ERROR;
  224. }
  225. }
  226. }
  227. return RET_OK;
  228. }
  229. STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
  230. MS_ASSERT(subGraph != nullptr);
  231. MS_ASSERT(node != nullptr);
  232. auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
  233. MS_ASSERT(constTensor != nullptr);
  234. const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
  235. if (constTensor->nodeType == schema::NodeType::NodeType_ValueNode &&
  236. constTensor->dataType == TypeId::kNumberTypeFloat) {
  237. size_t constTensorShapeSize = GetShapeSize(*constTensor);
  238. std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
  239. if (quantParam == nullptr) {
  240. MS_LOG(ERROR) << "new QuantParamT failed";
  241. return RET_NULL_PTR;
  242. }
  243. vector<uint8_t> qDatas(constTensorShapeSize);
  244. for (size_t j = 0; j < constTensorShapeSize; j++) {
  245. float rawData = constData[j];
  246. qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get());
  247. }
  248. constTensor->data = std::move(qDatas);
  249. constTensor->dataType = TypeId::kNumberTypeUInt8;
  250. }
  251. return RET_OK;
  252. }
  253. STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
  254. MS_ASSERT(graph != nullptr);
  255. MS_ASSERT(node != nullptr);
  256. auto inputIndexes = node->inputIndex;
  257. MS_ASSERT(inputIndexes.size() >= 3);
  258. MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0));
  259. MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1));
  260. MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2));
  261. auto &biasTensor = graph->allTensors.at(inputIndexes.at(2));
  262. MS_ASSERT(biasTensor != nullptr);
  263. if (biasTensor->dataType == TypeId::kNumberTypeInt32) {
  264. return RET_OK;
  265. }
  266. if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) {
  267. MS_LOG(ERROR) << "conv " << node->name << "'s bias data is not float";
  268. return RET_ERROR;
  269. }
  270. auto &inputTensor = graph->allTensors.at(inputIndexes.at(0));
  271. auto &weightTensor = graph->allTensors.at(inputIndexes.at(1));
  272. MS_ASSERT(inputTensor != nullptr);
  273. MS_ASSERT(weightTensor != nullptr);
  274. auto inputScale = inputTensor->quantParams.front()->scale;
  275. auto weightScale = weightTensor->quantParams.front()->scale;
  276. auto scale = inputScale * weightScale;
  277. // set bias quant param
  278. std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor);
  279. if (biasQuantParam == nullptr) {
  280. MS_LOG(ERROR) << "new QuantParamT failed";
  281. return RET_ERROR;
  282. }
  283. biasQuantParam->inited = true;
  284. biasQuantParam->scale = scale;
  285. biasQuantParam->zeroPoint = 0;
  286. biasQuantParam->numBits = 8;
  287. biasQuantParam->narrowRange = false;
  288. biasQuantParam->min = 0.0;
  289. biasQuantParam->max = 0.0;
  290. // quant bias data
  291. auto bShapeSize = GetShapeSize(*(biasTensor.get()));
  292. std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
  293. if (qDatas == nullptr) {
  294. MS_LOG(ERROR) << "new qDatas failed";
  295. return RET_ERROR;
  296. }
  297. void *biasData = biasTensor->data.data();
  298. auto *rawDatas = static_cast<float *>(biasData);
  299. for (size_t i = 0; i < bShapeSize; ++i) {
  300. qDatas[i] = (int32_t)std::round(rawDatas[i] / scale);
  301. }
  302. biasTensor->dataType = TypeId::kNumberTypeInt32;
  303. biasTensor->data.clear();
  304. biasTensor->data.resize(bShapeSize * sizeof(int32_t));
  305. auto ret =
  306. memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
  307. if (ret != EOK) {
  308. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  309. return RET_ERROR;
  310. }
  311. return RET_OK;
  312. }
  313. STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
  314. MS_ASSERT(subGraph != nullptr);
  315. MS_ASSERT(node != nullptr);
  316. MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
  317. auto inputIndexes = node->inputIndex;
  318. MS_ASSERT(inputIndexes.size() >= 2);
  319. MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
  320. auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1));
  321. if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
  322. return RET_OK;
  323. }
  324. if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
  325. weightTensor->dataType != TypeId::kNumberTypeUInt8) {
  326. MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
  327. return RET_ERROR;
  328. }
  329. size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
  330. void *oriWeightData = weightTensor->data.data();
  331. MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
  332. vector<int8_t> qDatas(wShapeSize);
  333. auto weightQauntParam = GetTensorQuantParam(weightTensor);
  334. if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
  335. weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant
  336. auto *weightData = static_cast<float *>(oriWeightData);
  337. for (size_t j = 0; j < wShapeSize; j++) {
  338. qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
  339. }
  340. } else { // tflite awareing quant
  341. auto *weightData = static_cast<uint8_t *>(oriWeightData);
  342. for (size_t j = 0; j < wShapeSize; j++) {
  343. qDatas[j] = (int32_t)weightData[j] - 128;
  344. }
  345. weightQauntParam->zeroPoint -= 128;
  346. weightTensor->quantParams.clear();
  347. weightTensor->quantParams.emplace_back(weightQauntParam.release());
  348. }
  349. ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize);
  350. weightTensor->dataType = TypeId::kNumberTypeInt8;
  351. return RET_OK;
  352. }
  353. STATUS AwareQuantizer::DetermineNodeQuantType() {
  354. MS_ASSERT(graph != nullptr);
  355. for (auto &node : graph->nodes) {
  356. MS_ASSERT(node != nullptr);
  357. bool canQuant = true;
  358. for (auto &outTensorIdx : node->outputIndex) {
  359. MS_ASSERT(graph->allTensors.size() > outTensorIdx);
  360. auto &outTensor = graph->allTensors.at(outTensorIdx);
  361. MS_ASSERT(outTensor != nullptr);
  362. if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
  363. !outTensor->quantParams.front()->inited) {
  364. canQuant = false;
  365. break;
  366. }
  367. }
  368. if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
  369. node->quantType = schema::QuantType_AwareTraining;
  370. } else {
  371. node->quantType = schema::QuantType_QUANT_NONE;
  372. }
  373. }
  374. return RET_OK;
  375. }
  376. } // namespace mindspore::lite::quant