|
|
|
@@ -15,20 +15,22 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "tools/converter/quantizer/aware_quantizer.h" |
|
|
|
|
|
|
|
#include <cmath> |
|
|
|
#include <memory> |
|
|
|
#include <string> |
|
|
|
#include <utility> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
#include "schema/inner/model_generated.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "securec/include/securec.h" |
|
|
|
#include "tools/converter/quantizer/quantize_util.h" |
|
|
|
#include "src/common/utils.h" |
|
|
|
#include "tools/converter/quantizer/calc_quant_param.h" |
|
|
|
#include "tools/common/tensor_util.h" |
|
|
|
#include "tools/common/converter_op_utils.h" |
|
|
|
#include "tools/common/node_util.h" |
|
|
|
#include "tools/common/tensor_util.h" |
|
|
|
#include "tools/converter/quantizer/calc_quant_param.h" |
|
|
|
#include "tools/converter/quantizer/quantize_util.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
|
|
|
|
using std::string; |
|
|
|
using std::vector; |
|
|
|
@@ -42,7 +44,8 @@ struct InputArray { |
|
|
|
int numBits = 8; |
|
|
|
TypeId dataType = TypeId::kTypeUnknown; |
|
|
|
|
|
|
|
InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) { |
|
|
|
InputArray(float mean, float stdDev, |
|
|
|
TypeId dataType = TypeId::kNumberTypeFloat) { |
|
|
|
this->dataType = dataType; |
|
|
|
constexpr float qmin = 0; |
|
|
|
constexpr float qmax = 255; |
|
|
|
@@ -52,7 +55,8 @@ struct InputArray { |
|
|
|
|
|
|
|
STATUS InitQuantParam() { |
|
|
|
this->quantParam = std::make_unique<schema::QuantParamT>(); |
|
|
|
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits); |
|
|
|
auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, |
|
|
|
narrowRange, numBits); |
|
|
|
if (status != RET_OK) { |
|
|
|
return status; |
|
|
|
} |
|
|
|
@@ -66,7 +70,8 @@ struct InputArray { |
|
|
|
if (!tensor->quantParams.empty()) { |
|
|
|
auto param = GetTensorQuantParam(tensor); |
|
|
|
if (param != nullptr && param->inited) { |
|
|
|
MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam"; |
|
|
|
MS_LOG(DEBUG) << "tensor " << inputTensorIdx |
|
|
|
<< " already has quantParam"; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
tensor->quantParams.clear(); |
|
|
|
@@ -83,11 +88,14 @@ struct InputArray { |
|
|
|
}; |
|
|
|
|
|
|
|
// Fixed table of schema::PrimitiveType values held by AwareQuantizer. The
// declared extent (7) has to stay equal to the number of initializers listed.
// The initializer list is shown twice below; the two renderings contain the
// same seven entries in the same order and differ only in line wrapping.
// NOTE(review): the name suggests these are ops whose quant params are
// propagated between input and output tensors -- confirm against the call sites.
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = { |
|
|
|
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, |
|
|
|
schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, |
|
|
|
schema::PrimitiveType_DetectionPostProcess}}; |
|
|
|
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, |
|
|
|
schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze, |
|
|
|
schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, |
|
|
|
schema::PrimitiveType_DetectionPostProcess}}; |
|
|
|
|
|
|
|
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, |
|
|
|
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, |
|
|
|
const string &inputInferType, |
|
|
|
const string &stdValues, |
|
|
|
const string &meanValues) |
|
|
|
: FbQuantizer(graph) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
@@ -110,9 +118,11 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
// MS_LOGE("GenerateDefaultQuantParam failed: %d", status); |
|
|
|
// return RET_ERROR; |
|
|
|
// } |
|
|
|
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { |
|
|
|
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); |
|
|
|
// iter++) { |
|
|
|
// auto *node = (*iter).get(); |
|
|
|
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { |
|
|
|
// if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && |
|
|
|
// GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { |
|
|
|
// continue; |
|
|
|
// } |
|
|
|
// auto inputIndexes = node->inputIndex; |
|
|
|
@@ -144,41 +154,43 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
// auto *maxData = reinterpret_cast<const float *>(tensor2->data.data()); |
|
|
|
// MS_ASSERT(minData != nullptr); |
|
|
|
// MS_ASSERT(maxData != nullptr); |
|
|
|
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT()); |
|
|
|
// if (quantParam == nullptr) { |
|
|
|
// std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) |
|
|
|
// QuantParamT()); if (quantParam == nullptr) { |
|
|
|
// MS_LOGE("new quantParam failed"); |
|
|
|
// return RET_ERROR; |
|
|
|
// } |
|
|
|
// auto realMin = (double)minData[0]; |
|
|
|
// auto realMax = (double)maxData[0]; |
|
|
|
// status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits); |
|
|
|
// if (status != RET_OK) { |
|
|
|
// MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str()); |
|
|
|
// return RET_ERROR; |
|
|
|
// status = CalQuantizationParams(quantParam.get(), realMin, realMax, |
|
|
|
// narrorRange, numBits); if (status != RET_OK) { |
|
|
|
// MS_LOGE("in aware quantization run CalQuantizationParams failed, |
|
|
|
// node: %s", node->name.c_str()); return RET_ERROR; |
|
|
|
// } |
|
|
|
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) { |
|
|
|
// CalFakeNode(tensor0, quantParam.get()); |
|
|
|
// } |
|
|
|
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) QuantParamArrayT()); |
|
|
|
// if (quantParamArray == nullptr) { |
|
|
|
// std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) |
|
|
|
// QuantParamArrayT()); if (quantParamArray == nullptr) { |
|
|
|
// MS_LOGE("new quantParamArray failed"); |
|
|
|
// return RET_ERROR; |
|
|
|
// } |
|
|
|
// quantParamArray->param.push_back(std::move(quantParam)); |
|
|
|
// auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray); |
|
|
|
// if (quantParamArrayCopy == nullptr) { |
|
|
|
// MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str()); |
|
|
|
// return RET_ERROR; |
|
|
|
// MS_LOGE("CopyQuantParamArray %s return nullptr", |
|
|
|
// iter->get()->name.c_str()); return RET_ERROR; |
|
|
|
// } |
|
|
|
// node->quantParam.emplace_back(std::move(quantParamArrayCopy)); |
|
|
|
// node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors that have no |
|
|
|
// preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray)); |
|
|
|
// node->quantParam.emplace_back(nullptr); // secondInTensor and |
|
|
|
// thirdInTensor are weightTensors that have no preNode |
|
|
|
// node->quantParam.emplace_back(nullptr); |
|
|
|
// node->quantParam.emplace_back(std::move(quantParamArray)); |
|
|
|
// |
|
|
|
// // BroadCast fakeQuantNode QuantParam |
|
|
|
// status = BroadCastQuantParam(subGraph, *iter); |
|
|
|
// if (status != RET_OK) { |
|
|
|
// MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status); |
|
|
|
// return status; |
|
|
|
// MS_LOGE("BroadCastQuantParam %s failed: %d", |
|
|
|
// iter->get()->name.c_str(), status); return status; |
|
|
|
// } |
|
|
|
// // save post node index for SetAttrToConvolution |
|
|
|
// auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node); |
|
|
|
@@ -189,10 +201,13 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
// return RET_ERROR; |
|
|
|
// } |
|
|
|
// // set filter param to node |
|
|
|
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) { |
|
|
|
// if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && |
|
|
|
// !postNodeIdxes.empty()) { |
|
|
|
// auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get(); |
|
|
|
// if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || |
|
|
|
// GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { |
|
|
|
// if (GetCNodeTType(*postNode) == OpT_Conv2D || |
|
|
|
// GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || |
|
|
|
// GetCNodeTType(*postNode) == OpT_DeConv2D || |
|
|
|
// GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { |
|
|
|
// auto status = SetAttrToConvolution(subGraph.get(), postNode); |
|
|
|
// if (status != RET_OK) { |
|
|
|
// MS_LOGE("in aware quant SetAttrToConvolution failed!"); |
|
|
|
@@ -203,7 +218,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
// } |
|
|
|
// |
|
|
|
// // remove IsolatedNode |
|
|
|
// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) { |
|
|
|
// for (auto iter = subGraph->nodes.begin(); iter != |
|
|
|
// subGraph->nodes.end();) { |
|
|
|
// if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { |
|
|
|
// iter = subGraph->nodes.erase(iter); |
|
|
|
// } else { |
|
|
|
@@ -213,8 +229,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
// // set graphInputNode inputTensor quantParams |
|
|
|
// MS_ASSERT(subGraph->inputIndex.size() == 1); |
|
|
|
// for (auto graphInputIndex : subGraph->inputIndex) { |
|
|
|
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex); |
|
|
|
// for (auto nodeIdx : linkedPostIdx) { |
|
|
|
// auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), |
|
|
|
// graphInputIndex); for (auto nodeIdx : linkedPostIdx) { |
|
|
|
// MS_ASSERT(subGraph->nodes.size() > nodeIdx); |
|
|
|
// mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get()); |
|
|
|
// } |
|
|
|
@@ -223,7 +239,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { |
|
|
|
STATUS AwareQuantizer::GenerateDefaultQuantParam( |
|
|
|
const schema::MetaGraphT *subGraph) { |
|
|
|
MS_ASSERT(subGraph != nullptr); |
|
|
|
for (const auto &tensor : subGraph->allTensors) { |
|
|
|
if (!tensor->quantParams.empty()) { |
|
|
|
@@ -235,15 +252,18 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGr |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, |
|
|
|
schema::CNodeT *node) { |
|
|
|
// MS_ASSERT(subGraph != nullptr); |
|
|
|
// MS_ASSERT(node != nullptr); |
|
|
|
// auto inputIndexes = node->inputIndex; |
|
|
|
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D || |
|
|
|
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D); |
|
|
|
// MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == |
|
|
|
// OpT_DepthwiseConv2D || |
|
|
|
// GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == |
|
|
|
// OpT_DeDepthwiseConv2D); |
|
|
|
// if (inputIndexes.size() < 2) { |
|
|
|
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size()); |
|
|
|
// return RET_ERROR; |
|
|
|
// MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", |
|
|
|
// node->name.c_str(), inputIndexes.size()); return RET_ERROR; |
|
|
|
// } |
|
|
|
// TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get(); |
|
|
|
// MS_ASSERT(filterTensor != nullptr); |
|
|
|
@@ -267,14 +287,16 @@ STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, |
|
|
|
// if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) { |
|
|
|
// if (node->fmkType == FmkType_MS) { |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelMultiplier = |
|
|
|
// (int32_t)filterDims[1]; node->attr.AsDepthwiseConv2D()->kernelH = |
|
|
|
// (int32_t)filterDims[2]; node->attr.AsDepthwiseConv2D()->kernelW = |
|
|
|
// (int32_t)filterDims[3]; |
|
|
|
// } else if (node->fmkType == FmkType_TF) { |
|
|
|
// node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3]; |
|
|
|
// node->attr.AsDepthwiseConv2D()->channelMultiplier = |
|
|
|
// (int32_t)filterDims[3]; |
|
|
|
// } else { |
|
|
|
// MS_LOGE("Unsupport"); |
|
|
|
// } |
|
|
|
@@ -313,15 +335,19 @@ STATUS AwareQuantizer::GenerateQuantParam() { |
|
|
|
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { |
|
|
|
MS_ASSERT(false); |
|
|
|
} |
|
|
|
auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); |
|
|
|
auto *quantParamCalcer = |
|
|
|
quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); |
|
|
|
if (quantParamCalcer == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() |
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; |
|
|
|
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " |
|
|
|
<< node->name.c_str() |
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() |
|
|
|
<< " set node to QuantNone and skip"; |
|
|
|
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); |
|
|
|
} else { |
|
|
|
status = quantParamCalcer->Calc(graph, *node); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); |
|
|
|
MS_LOG(ERROR) << "quantParamCalcer failed: " << status |
|
|
|
<< " node: " << node->name.c_str(); |
|
|
|
node->quantType = schema::QuantType_QUANT_NONE; |
|
|
|
} else { |
|
|
|
node->quantType = schema::QuantType_AwareTraining; |
|
|
|
@@ -345,7 +371,8 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) { |
|
|
|
auto inputIndexes = node->inputIndex; |
|
|
|
if (inputIndexes.size() < 2) { |
|
|
|
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; |
|
|
|
MS_LOG(ERROR) << node->name.c_str() |
|
|
|
<< " node input has invalid inputs tensor count"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// quant weight |
|
|
|
@@ -362,7 +389,8 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { |
|
|
|
} else if (GetCNodeTType(*node) == |
|
|
|
schema::PrimitiveType_DetectionPostProcess) { |
|
|
|
status = QuantDetectionPostProcessConstTensor(graph, node.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; |
|
|
|
@@ -388,7 +416,8 @@ STATUS AwareQuantizer::DoQuantize() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, |
|
|
|
schema::CNodeT *node) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
for (size_t i = 0; i < node->inputIndex.size(); i++) { |
|
|
|
@@ -407,7 +436,8 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche |
|
|
|
void *inData = inTensor->data.data(); |
|
|
|
auto *castedInData = static_cast<float *>(inData); |
|
|
|
for (size_t j = 0; j < constTensorShapeSize; j++) { |
|
|
|
qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get()); |
|
|
|
qDatas[j] = |
|
|
|
QuantizeData<uint8_t>(castedInData[j], quantParam.get()); |
|
|
|
} |
|
|
|
inTensor->data = std::move(qDatas); |
|
|
|
inTensor->dataType = kNumberTypeUInt8; |
|
|
|
@@ -423,14 +453,17 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor( |
|
|
|
const schema::MetaGraphT *subGraph, schema::CNodeT *node) { |
|
|
|
MS_ASSERT(subGraph != nullptr); |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); |
|
|
|
MS_ASSERT(constTensor != nullptr); |
|
|
|
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data()); |
|
|
|
const auto *constData = |
|
|
|
reinterpret_cast<const float *>(constTensor->data.data()); |
|
|
|
|
|
|
|
if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) { |
|
|
|
if (constTensor->refCount == 999 && |
|
|
|
constTensor->dataType == TypeId::kNumberTypeFloat) { |
|
|
|
size_t constTensorShapeSize = GetShapeSize(*constTensor); |
|
|
|
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor); |
|
|
|
if (quantParam == nullptr) { |
|
|
|
@@ -448,7 +481,8 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, |
|
|
|
mindspore::schema::CNodeT *node) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
auto inputIndexes = node->inputIndex; |
|
|
|
@@ -507,7 +541,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, |
|
|
|
biasTensor->dataType = TypeId::kNumberTypeInt32; |
|
|
|
biasTensor->data.clear(); |
|
|
|
biasTensor->data.resize(bShapeSize * sizeof(int32_t)); |
|
|
|
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t)); |
|
|
|
auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), |
|
|
|
qDatas, bShapeSize * sizeof(int32_t)); |
|
|
|
if (ret != EOK) { |
|
|
|
// MS_LOGE("memcpy_s failed: %d", ret); |
|
|
|
return RET_ERROR; |
|
|
|
@@ -516,10 +551,12 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { |
|
|
|
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, |
|
|
|
schema::CNodeT *node) { |
|
|
|
MS_ASSERT(subGraph != nullptr); |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); |
|
|
|
MS_ASSERT(node->quantParam.size() == |
|
|
|
node->inputIndex.size() + node->outputIndex.size()); |
|
|
|
auto inputIndexes = node->inputIndex; |
|
|
|
MS_ASSERT(inputIndexes.size() >= 2); |
|
|
|
MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); |
|
|
|
@@ -527,8 +564,11 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem |
|
|
|
if (weightTensor->dataType == TypeId::kNumberTypeInt8) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { |
|
|
|
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; |
|
|
|
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && |
|
|
|
weightTensor->dataType != TypeId::kNumberTypeFloat && |
|
|
|
weightTensor->dataType != TypeId::kNumberTypeUInt8) { |
|
|
|
MS_LOG(ERROR) << "conv " << node->name.c_str() |
|
|
|
<< "'s weight data is not float or uint8"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
size_t wShapeSize = GetShapeSize(*(weightTensor.get())); |
|
|
|
@@ -536,7 +576,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem |
|
|
|
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); |
|
|
|
vector<int8_t> qDatas(wShapeSize); |
|
|
|
auto weightQauntParam = GetTensorQuantParam(weightTensor); |
|
|
|
if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal aware-training quant |
|
|
|
if (weightTensor->dataType == |
|
|
|
TypeId::kNumberTypeFloat) { // normal aware-training quant |
|
|
|
auto *weightData = static_cast<float *>(oriWeightData); |
|
|
|
for (size_t j = 0; j < wShapeSize; j++) { |
|
|
|
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); |
|
|
|
@@ -564,7 +605,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { |
|
|
|
MS_ASSERT(graph->allTensors.size() > inTensorIdx); |
|
|
|
auto &inTensor = graph->allTensors.at(inTensorIdx); |
|
|
|
MS_ASSERT(inTensor != nullptr); |
|
|
|
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || |
|
|
|
if (inTensor->quantParams.empty() || |
|
|
|
inTensor->quantParams.front() == nullptr || |
|
|
|
!inTensor->quantParams.front()->inited) { |
|
|
|
canQuant = false; |
|
|
|
break; |
|
|
|
@@ -576,7 +618,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { |
|
|
|
MS_ASSERT(graph->allTensors.size() > outTensorIdx); |
|
|
|
auto &outTensor = graph->allTensors.at(outTensorIdx); |
|
|
|
MS_ASSERT(outTensor != nullptr); |
|
|
|
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || |
|
|
|
if (outTensor->quantParams.empty() || |
|
|
|
outTensor->quantParams.front() == nullptr || |
|
|
|
!outTensor->quantParams.front()->inited) { |
|
|
|
canQuant = false; |
|
|
|
break; |
|
|
|
|