diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index 1ad5189f12..9b5c452e45 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -79,7 +79,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
     schema::PrimitiveType_Squeeze,      schema::PrimitiveType_Sub,
     schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
     schema::PrimitiveType_Unsqueeze,    schema::PrimitiveType_MatMul,
-    schema::PrimitiveType_Pad,          schema::PrimitiveType_DeConv2D};
+    schema::PrimitiveType_Pad,          schema::PrimitiveType_DeConv2D,
+    schema::PrimitiveType_Scale};
 static const std::vector<schema::PrimitiveType> needInsertOpList = {
   schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
index 3ec4cf5f6a..5913d474c2 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
@@ -106,7 +106,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
   for (auto graphInputIndex : graph->inputIndex) {
     auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "SetInputArrayQP failed";
+      MS_LOG(WARNING) << "SetInputArrayQP failed";
       return status;
     }
   }
@@ -121,8 +121,8 @@ STATUS AwareQuantizer::GenerateQuantParam() {
     }
     auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
     if (quantParamCalcer == nullptr) {
-      MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
-                    << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
+      MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
+                      << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
       node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
     } else {
       auto status = quantParamCalcer->Calc(graph, *node);
@@ -154,7 +154,7 @@ STATUS AwareQuantizer::DoQuantize() {
       GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
     auto inputIndexes = node->inputIndex;
     if (inputIndexes.size() < 2) {
-      MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
+      MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count";
       return RET_ERROR;
     }
     // quant weight
@@ -162,7 +162,7 @@ STATUS AwareQuantizer::DoQuantize() {
     if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
       status = QuantConvWeight(graph, node.get());
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "QuantConvWeight failed!";
+        MS_LOG(WARNING) << "QuantConvWeight failed!";
         return RET_ERROR;
       }
     }
@@ -172,7 +172,7 @@ STATUS AwareQuantizer::DoQuantize() {
     if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
       status = QuantConvBias(graph, node.get());
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "QuantConvBias failed!";
+        MS_LOG(WARNING) << "QuantConvBias failed!";
         return RET_ERROR;
       }
     }
@@ -180,13 +180,15 @@ STATUS AwareQuantizer::DoQuantize() {
   } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
     status = QuantDetectionPostProcessConstTensor(graph, node.get());
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
+      MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!";
       return RET_ERROR;
     }
-  } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
-    status = QuantAddConstTensor(graph, node.get());
+  } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add ||
+             GetCNodeTType(*node) == schema::PrimitiveType_Scale ||
+             GetCNodeTType(*node) == schema::PrimitiveType_Mul) {
+    status = QuantArithmeticConstTensor(graph, node.get());
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "QuantAddConstTensor failed!";
+      MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!";
       return RET_ERROR;
     }
   }
@@ -203,7 +205,7 @@ STATUS AwareQuantizer::DoQuantize() {
   return RET_OK;
 }
 
-STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
+STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
   MS_ASSERT(graph != nullptr);
   MS_ASSERT(node != nullptr);
   for (size_t i = 0; i < node->inputIndex.size(); i++) {
@@ -211,28 +213,40 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche
     MS_ASSERT(graph->allTensors.size() > inTensorIdx);
     auto &inTensor = graph->allTensors.at(inTensorIdx);
     MS_ASSERT(inTensor != nullptr);
-    if (inTensor->refCount == 999) {
-      switch (inTensor->dataType) {
-        case TypeId::kNumberTypeFloat: {
-          auto quantParam = GetTensorQuantParam(inTensor);
-          MS_ASSERT(quantParam != nullptr);
-          MS_ASSERT(quantParam->inited);
-          auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
-          vector<uint8_t> qDatas(constTensorShapeSize);
-          void *inData = inTensor->data.data();
-          auto *castedInData = static_cast<float *>(inData);
-          for (size_t j = 0; j < constTensorShapeSize; j++) {
-            qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
-          }
-          inTensor->data = std::move(qDatas);
-          inTensor->dataType = kNumberTypeUInt8;
-        } break;
-        case kNumberTypeUInt8:
-          break;
-        default:
-          MS_LOG(ERROR) << "Unsupported dataType: " << inTensor->dataType;
-          return RET_ERROR;
+    if (!inTensor->data.empty()) {
+      if (inTensor->dataType == TypeId::kNumberTypeInt8) {
+        continue;
       }
+      if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat &&
+          inTensor->dataType != TypeId::kNumberTypeUInt8) {
+        MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8";
+        return RET_ERROR;
+      }
+
+      auto quantParam = GetTensorQuantParam(inTensor);
+      MS_ASSERT(quantParam != nullptr);
+      MS_ASSERT(quantParam->inited);
+      auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
+      vector<int8_t> qDatas(constTensorShapeSize);
+      void *inData = inTensor->data.data();
+      if (inTensor->dataType == TypeId::kNumberTypeFloat ||
+          inTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal awareing quant
+        auto *weightData = static_cast<float *>(inData);
+        for (size_t j = 0; j < constTensorShapeSize; j++) {
+          qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get());
+        }
+      } else {  // tflite awareing quant
+        auto *weightData = static_cast<uint8_t *>(inData);
+        for (size_t j = 0; j < constTensorShapeSize; j++) {
+          qDatas[j] = (int32_t)weightData[j] - 128;
+        }
+        quantParam->zeroPoint -= 128;
+        inTensor->quantParams.clear();
+        inTensor->quantParams.emplace_back(quantParam.release());
+      }
+
+      ::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize);
+      inTensor->dataType = TypeId::kNumberTypeInt8;
     }
   }
   return RET_OK;
 }
@@ -245,21 +259,21 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr
   MS_ASSERT(constTensor != nullptr);
   const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
-  if (constTensor->nodeType == schema::NodeType::NodeType_ValueNode &&
-      constTensor->dataType == TypeId::kNumberTypeFloat) {
+  if (!constTensor->data.empty() &&
+      (constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) {
     size_t constTensorShapeSize = GetShapeSize(*constTensor);
     std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
     if (quantParam == nullptr) {
       MS_LOG(ERROR) << "new QuantParamT failed";
       return RET_NULL_PTR;
     }
-    vector<uint8_t> qDatas(constTensorShapeSize);
+    vector<int8_t> qDatas(constTensorShapeSize);
     for (size_t j = 0; j < constTensorShapeSize; j++) {
       float rawData = constData[j];
-      qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get());
+      qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get());
     }
-    constTensor->data = std::move(qDatas);
-    constTensor->dataType = TypeId::kNumberTypeUInt8;
+    ::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize);
+    constTensor->dataType = TypeId::kNumberTypeInt8;
   }
   return RET_OK;
 }
@@ -340,13 +354,14 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
   }
   if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
       weightTensor->dataType != TypeId::kNumberTypeUInt8) {
-    MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
+    MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
     return RET_ERROR;
   }
   size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
   void *oriWeightData = weightTensor->data.data();
   MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
   vector<int8_t> qDatas(wShapeSize);
+  // todo support perchannel
   auto weightQauntParam = GetTensorQuantParam(weightTensor);
   if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
       weightTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal awareing quant
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
index f7038b1126..61d0318330 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
@@ -67,7 +67,7 @@ class AwareQuantizer : public FbQuantizer {
 
   STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);
 
-  STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
+  STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
 
   STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
 
diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
index 2370c5e5e8..44c7f2acc2 100644
--- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
+++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
@@ -474,6 +474,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
   _registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
   _registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
   _registerMap[schema::PrimitiveType_Mul] = commonCalcer;
+  _registerMap[schema::PrimitiveType_Scale] = commonCalcer;
   _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
   _registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer;
   _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;