|
|
|
@@ -293,13 +293,13 @@ STATUS AwareQuantizer::GenerateQuantParam() { |
|
|
|
MS_ASSERT(graph->inputIndex.size() == 1); |
|
|
|
// set graphInputNode input |
|
|
|
for (auto graphInputIndex : graph->inputIndex) { |
|
|
|
auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex); |
|
|
|
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "SetInputArrayQP failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
auto status = GenerateDefaultQuantParam(graph.get()); |
|
|
|
auto status = GenerateDefaultQuantParam(graph); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; |
|
|
|
return status; |
|
|
|
@@ -319,7 +319,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { |
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; |
|
|
|
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); |
|
|
|
} else { |
|
|
|
status = quantParamCalcer->Calc(graph.get(), *node); |
|
|
|
status = quantParamCalcer->Calc(graph, *node); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); |
|
|
|
node->quantType = schema::QuantType_QUANT_NONE; |
|
|
|
@@ -349,27 +349,27 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// quant weight |
|
|
|
status = QuantConvWeight(graph.get(), node.get()); |
|
|
|
status = QuantConvWeight(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantConvWeight failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// quant bias |
|
|
|
if (inputIndexes.size() == 3) { |
|
|
|
status = QuantConvBias(graph.get(), node.get()); |
|
|
|
status = QuantConvBias(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantConvBias failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { |
|
|
|
status = QuantDetectionPostProcessConstTensor(graph.get(), node.get()); |
|
|
|
status = QuantDetectionPostProcessConstTensor(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) { |
|
|
|
status = QuantAddConstTensor(graph.get(), node.get()); |
|
|
|
status = QuantAddConstTensor(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantAddConstTensor failed!"; |
|
|
|
return RET_ERROR; |
|
|
|
|