|
|
|
@@ -106,7 +106,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { |
|
|
|
for (auto graphInputIndex : graph->inputIndex) { |
|
|
|
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "SetInputArrayQP failed"; |
|
|
|
MS_LOG(WARNING) << "SetInputArrayQP failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -121,8 +121,8 @@ STATUS AwareQuantizer::GenerateQuantParam() { |
|
|
|
} |
|
|
|
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); |
|
|
|
if (quantParamCalcer == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() |
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; |
|
|
|
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() |
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; |
|
|
|
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); |
|
|
|
} else { |
|
|
|
auto status = quantParamCalcer->Calc(graph, *node); |
|
|
|
@@ -154,7 +154,7 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) { |
|
|
|
auto inputIndexes = node->inputIndex; |
|
|
|
if (inputIndexes.size() < 2) { |
|
|
|
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; |
|
|
|
MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// quant weight |
|
|
|
@@ -162,7 +162,7 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { |
|
|
|
status = QuantConvWeight(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantConvWeight failed!"; |
|
|
|
MS_LOG(WARNING) << "QuantConvWeight failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -172,7 +172,7 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { |
|
|
|
status = QuantConvBias(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantConvBias failed!"; |
|
|
|
MS_LOG(WARNING) << "QuantConvBias failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -180,13 +180,15 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { |
|
|
|
status = QuantDetectionPostProcessConstTensor(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; |
|
|
|
MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) { |
|
|
|
status = QuantAddConstTensor(graph, node.get()); |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add || |
|
|
|
GetCNodeTType(*node) == schema::PrimitiveType_Scale || |
|
|
|
GetCNodeTType(*node) == schema::PrimitiveType_Mul) { |
|
|
|
status = QuantArithmeticConstTensor(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantAddConstTensor failed!"; |
|
|
|
MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -203,7 +205,7 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
for (size_t i = 0; i < node->inputIndex.size(); i++) { |
|
|
|
@@ -211,28 +213,40 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche |
|
|
|
MS_ASSERT(graph->allTensors.size() > inTensorIdx); |
|
|
|
auto &inTensor = graph->allTensors.at(inTensorIdx); |
|
|
|
MS_ASSERT(inTensor != nullptr); |
|
|
|
if (inTensor->refCount == 999) { |
|
|
|
switch (inTensor->dataType) { |
|
|
|
case TypeId::kNumberTypeFloat: { |
|
|
|
auto quantParam = GetTensorQuantParam(inTensor); |
|
|
|
MS_ASSERT(quantParam != nullptr); |
|
|
|
MS_ASSERT(quantParam->inited); |
|
|
|
auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); |
|
|
|
vector<uint8_t> qDatas(constTensorShapeSize); |
|
|
|
void *inData = inTensor->data.data(); |
|
|
|
auto *castedInData = static_cast<float *>(inData); |
|
|
|
for (size_t j = 0; j < constTensorShapeSize; j++) { |
|
|
|
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get()); |
|
|
|
} |
|
|
|
inTensor->data = std::move(qDatas); |
|
|
|
inTensor->dataType = kNumberTypeUInt8; |
|
|
|
} break; |
|
|
|
case kNumberTypeUInt8: |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupported dataType: " << inTensor->dataType; |
|
|
|
return RET_ERROR; |
|
|
|
if (!inTensor->data.empty()) { |
|
|
|
if (inTensor->dataType == TypeId::kNumberTypeInt8) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat && |
|
|
|
inTensor->dataType != TypeId::kNumberTypeUInt8) { |
|
|
|
MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
auto quantParam = GetTensorQuantParam(inTensor); |
|
|
|
MS_ASSERT(quantParam != nullptr); |
|
|
|
MS_ASSERT(quantParam->inited); |
|
|
|
auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); |
|
|
|
vector<int8_t> qDatas(constTensorShapeSize); |
|
|
|
void *inData = inTensor->data.data(); |
|
|
|
if (inTensor->dataType == TypeId::kNumberTypeFloat || |
|
|
|
inTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant |
|
|
|
auto *weightData = static_cast<float *>(inData); |
|
|
|
for (size_t j = 0; j < constTensorShapeSize; j++) { |
|
|
|
qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get()); |
|
|
|
} |
|
|
|
} else { // tflite awareing quant |
|
|
|
auto *weightData = static_cast<uint8_t *>(inData); |
|
|
|
for (size_t j = 0; j < constTensorShapeSize; j++) { |
|
|
|
qDatas[j] = (int32_t)weightData[j] - 128; |
|
|
|
} |
|
|
|
quantParam->zeroPoint -= 128; |
|
|
|
inTensor->quantParams.clear(); |
|
|
|
inTensor->quantParams.emplace_back(quantParam.release()); |
|
|
|
} |
|
|
|
|
|
|
|
::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize); |
|
|
|
inTensor->dataType = TypeId::kNumberTypeInt8; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -245,21 +259,21 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr |
|
|
|
MS_ASSERT(constTensor != nullptr); |
|
|
|
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data()); |
|
|
|
|
|
|
|
if (constTensor->nodeType == schema::NodeType::NodeType_ValueNode && |
|
|
|
constTensor->dataType == TypeId::kNumberTypeFloat) { |
|
|
|
if (!constTensor->data.empty() && |
|
|
|
(constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) { |
|
|
|
size_t constTensorShapeSize = GetShapeSize(*constTensor); |
|
|
|
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor); |
|
|
|
if (quantParam == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new QuantParamT failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
vector<uint8_t> qDatas(constTensorShapeSize); |
|
|
|
vector<int8_t> qDatas(constTensorShapeSize); |
|
|
|
for (size_t j = 0; j < constTensorShapeSize; j++) { |
|
|
|
float rawData = constData[j]; |
|
|
|
qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get()); |
|
|
|
qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get()); |
|
|
|
} |
|
|
|
constTensor->data = std::move(qDatas); |
|
|
|
constTensor->dataType = TypeId::kNumberTypeUInt8; |
|
|
|
::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize); |
|
|
|
constTensor->dataType = TypeId::kNumberTypeInt8; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -340,13 +354,14 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem |
|
|
|
} |
|
|
|
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && |
|
|
|
weightTensor->dataType != TypeId::kNumberTypeUInt8) { |
|
|
|
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; |
|
|
|
MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
size_t wShapeSize = GetShapeSize(*(weightTensor.get())); |
|
|
|
void *oriWeightData = weightTensor->data.data(); |
|
|
|
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); |
|
|
|
vector<int8_t> qDatas(wShapeSize); |
|
|
|
// todo support perchannel |
|
|
|
auto weightQauntParam = GetTensorQuantParam(weightTensor); |
|
|
|
if (weightTensor->dataType == TypeId::kNumberTypeFloat || |
|
|
|
weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant |
|
|
|
|