You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

aware_quantizer.cc 6.0 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/converter/quantizer/aware_quantizer.h"
  17. #include <cmath>
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. #include <vector>
  22. #include "schema/inner/model_generated.h"
  23. #include "securec/include/securec.h"
  24. #include "src/common/utils.h"
  25. #include "tools/common/node_util.h"
  26. #include "tools/common/tensor_util.h"
  27. #include "tools/converter/quantizer/calc_quant_param.h"
  28. #include "src/common/log_adapter.h"
  29. using std::string;
  30. using std::vector;
  31. namespace mindspore::lite::quant {
  32. AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {}
// Fake-quant removal is a no-op for this quantizer: FakeQuant nodes are not
// expected to reach this stage (GenerateQuantParam asserts if one is found).
STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
  34. STATUS AwareQuantizer::GenerateQuantParam() {
  35. auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
  36. for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
  37. auto &node = *iter;
  38. MS_ASSERT(node != nullptr);
  39. if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
  40. GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
  41. MS_ASSERT(false);
  42. }
  43. auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
  44. if (quantParamCalcer == nullptr) {
  45. MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
  46. << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
  47. node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
  48. } else {
  49. auto status = quantParamCalcer->Calc(graph, *node);
  50. if (status != RET_OK) {
  51. MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
  52. node->quantType = schema::QuantType_QUANT_NONE;
  53. } else {
  54. node->quantType = schema::QuantType_AwareTraining;
  55. }
  56. }
  57. }
  58. return RET_OK;
  59. }
  60. STATUS AwareQuantizer::DoQuantize() {
  61. for (auto &tensor : graph->allTensors) {
  62. if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
  63. continue;
  64. }
  65. if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
  66. tensor->dataType != TypeId::kNumberTypeUInt8) {
  67. continue;
  68. }
  69. // perlayer
  70. if (tensor->quantParams.size() == 1) {
  71. auto &quantParam = tensor->quantParams.front();
  72. size_t wShapeSize = GetShapeSize(*(tensor.get()));
  73. void *oriWeightData = tensor->data.data();
  74. if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
  75. vector<int8_t> qDatas(wShapeSize);
  76. auto weightQauntParam = GetTensorQuantParam(tensor);
  77. if (tensor->dataType == TypeId::kNumberTypeFloat ||
  78. tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant
  79. auto *weightData = static_cast<float *>(oriWeightData);
  80. for (size_t j = 0; j < wShapeSize; j++) {
  81. qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
  82. }
  83. } else { // tflite awareing quant
  84. auto *weightData = static_cast<uint8_t *>(oriWeightData);
  85. for (size_t j = 0; j < wShapeSize; j++) {
  86. qDatas[j] = (int32_t)weightData[j] - 128;
  87. }
  88. weightQauntParam->zeroPoint -= 128;
  89. tensor->quantParams.clear();
  90. tensor->quantParams.emplace_back(weightQauntParam.release());
  91. }
  92. ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize);
  93. } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
  94. // quant bias data
  95. auto bShapeSize = GetShapeSize(*(tensor.get()));
  96. std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
  97. if (qDatas == nullptr) {
  98. MS_LOG(ERROR) << "new qDatas failed";
  99. return RET_ERROR;
  100. }
  101. void *biasData = tensor->data.data();
  102. auto *rawDatas = static_cast<float *>(biasData);
  103. for (size_t i = 0; i < bShapeSize; ++i) {
  104. qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
  105. }
  106. tensor->dataType = TypeId::kNumberTypeInt32;
  107. tensor->data.clear();
  108. tensor->data.resize(bShapeSize * sizeof(int32_t));
  109. auto ret =
  110. memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
  111. if (ret != EOK) {
  112. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  113. return RET_ERROR;
  114. }
  115. }
  116. } else { // pertensor
  117. }
  118. }
  119. return RET_OK;
  120. }
  121. STATUS AwareQuantizer::DetermineNodeQuantType() {
  122. MS_ASSERT(graph != nullptr);
  123. for (auto &node : graph->nodes) {
  124. MS_ASSERT(node != nullptr);
  125. bool canQuant = true;
  126. for (auto &outTensorIdx : node->outputIndex) {
  127. MS_ASSERT(graph->allTensors.size() > outTensorIdx);
  128. auto &outTensor = graph->allTensors.at(outTensorIdx);
  129. MS_ASSERT(outTensor != nullptr);
  130. if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
  131. !outTensor->quantParams.front()->inited) {
  132. canQuant = false;
  133. break;
  134. }
  135. }
  136. if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
  137. node->quantType = schema::QuantType_AwareTraining;
  138. } else {
  139. node->quantType = schema::QuantType_QUANT_NONE;
  140. }
  141. }
  142. return RET_OK;
  143. }
  144. } // namespace mindspore::lite::quant