Browse Source

!4430 fix post training quant

Merge pull request !4430 from xutianchun/quant_0814
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
89127ccf65
3 changed files with 25 additions and 23 deletions
  1. +1
    -1
      mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc
  2. +5
    -5
      mindspore/lite/tools/converter/converter.cc
  3. +19
    -17
      mindspore/lite/tools/converter/quantizer/quantize_util.cc

+ 1
- 1
mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc View File

@@ -89,7 +89,7 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
} }
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter // add quant parameter
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
if (cNode->quantType == schema::QuantType_AwareTrainning) {
primTValue->SetQuantType(cNode->quantType); primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) { for (int index : cNode->inputIndex) {
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));


+ 5
- 5
mindspore/lite/tools/converter/converter.cc View File

@@ -141,11 +141,11 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
// flags->bitNum)); // flags->bitNum));
// break; // break;
// } // }
// case mindspore::schema::QuantType_PostTraining: {
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
// break;
// }
case mindspore::schema::QuantType_PostTraining: {
MS_LOG(INFO) << "create PostTrainningQuantizer!";
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
break;
}
case mindspore::schema::QuantType_QUANT_NONE: case mindspore::schema::QuantType_QUANT_NONE:
MS_LOG(INFO) << "Not do quantization for model!"; MS_LOG(INFO) << "Not do quantization for model!";
break; break;


+ 19
- 17
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -308,21 +308,24 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl


STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
bool per_channel) { bool per_channel) {
if (per_channel) {
// per channel
auto dims = weightPtr->tensor_shape();
if (dims.size() < 1) {
MS_LOG(ERROR) << "weight dims size error";
return RET_ERROR;
}
// todo(x)
auto dims = weightPtr->tensor_shape();
if (dims.size() != 4) {
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
per_channel = false;
} else {
uint32_t channels = dims[3]; uint32_t channels = dims[3];
if (channels == 0) { if (channels == 0) {
MS_LOG(ERROR) << "channels error 0";
MS_LOG(ERROR) << "channels is 0";
return RET_ERROR; return RET_ERROR;
} }
}


if (per_channel) {
// Note:
// Currently, for TFLite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's.
// If TransWeightFormat is done before post-training quantization, DepthwiseConv2D's weight is CHWK.
size_t shapeSize = weightPtr->tensor_shape_size(); size_t shapeSize = weightPtr->tensor_shape_size();
auto channels = dims[3];
size_t oneFilterSize = shapeSize / channels; size_t oneFilterSize = shapeSize / channels;
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) { if (rawDatas == nullptr) {
@@ -330,17 +333,17 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
return RET_ERROR; return RET_ERROR;
} }


float min = FLT_MAX;
float max = FLT_MIN;
weightPtr->quant_param().clear(); weightPtr->quant_param().clear();
vector<int8_t> qDatas(shapeSize); vector<int8_t> qDatas(shapeSize);

for (uint32_t i = 0; i < channels; i++) { for (uint32_t i = 0; i < channels; i++) {
float min = 0;
float max = 0;
// find min and max // find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) { for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[j + i * oneFilterSize]);
max = std::max(max, rawDatas[j + i * oneFilterSize]);
min = std::min(min, rawDatas[i + j * oneFilterSize]);
max = std::max(max, rawDatas[i + j * oneFilterSize]);
} }

std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) { if (status != RET_OK) {
@@ -349,11 +352,10 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
} }
// update data and datatype // update data and datatype
for (uint32_t j = 0; j < oneFilterSize; j++) { for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[j + i * oneFilterSize];
float rawData = rawDatas[i + j * oneFilterSize];
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min); auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
qDatas[j + i * oneFilterSize] = qData;
qDatas[i + j * oneFilterSize] = qData;
} }

weightPtr->set_quant_param(quantParam); weightPtr->set_quant_param(quantParam);
} }
auto ret = auto ret =


Loading…
Cancel
Save