|
|
|
@@ -308,21 +308,24 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl |
|
|
|
|
|
|
|
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, |
|
|
|
bool per_channel) { |
|
|
|
if (per_channel) { |
|
|
|
// per channel |
|
|
|
auto dims = weightPtr->tensor_shape(); |
|
|
|
if (dims.size() < 1) { |
|
|
|
MS_LOG(ERROR) << "weight dims size error"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
// todo(x) |
|
|
|
auto dims = weightPtr->tensor_shape(); |
|
|
|
if (dims.size() != 4) { |
|
|
|
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer."; |
|
|
|
per_channel = false; |
|
|
|
} else { |
|
|
|
uint32_t channels = dims[3]; |
|
|
|
if (channels == 0) { |
|
|
|
MS_LOG(ERROR) << "channels error 0"; |
|
|
|
MS_LOG(ERROR) << "channels is 0"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (per_channel) { |
|
|
|
// notice: |
|
|
|
// currently, for tflite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's |
|
|
|
// if TransWeightFormat is done before PostTrainingQuantization, the DepthwiseConv2D's weight is CHWK |
|
|
|
size_t shapeSize = weightPtr->tensor_shape_size(); |
|
|
|
auto channels = dims[3]; |
|
|
|
size_t oneFilterSize = shapeSize / channels; |
|
|
|
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); |
|
|
|
if (rawDatas == nullptr) { |
|
|
|
@@ -330,17 +333,17 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
float min = FLT_MAX; |
|
|
|
float max = FLT_MIN; |
|
|
|
weightPtr->quant_param().clear(); |
|
|
|
vector<int8_t> qDatas(shapeSize); |
|
|
|
|
|
|
|
for (uint32_t i = 0; i < channels; i++) { |
|
|
|
float min = 0; |
|
|
|
float max = 0; |
|
|
|
// find min and max |
|
|
|
for (uint32_t j = 0; j < oneFilterSize; j++) { |
|
|
|
min = std::min(min, rawDatas[j + i * oneFilterSize]); |
|
|
|
max = std::max(max, rawDatas[j + i * oneFilterSize]); |
|
|
|
min = std::min(min, rawDatas[i + j * oneFilterSize]); |
|
|
|
max = std::max(max, rawDatas[i + j * oneFilterSize]); |
|
|
|
} |
|
|
|
|
|
|
|
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); |
|
|
|
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); |
|
|
|
if (status != RET_OK) { |
|
|
|
@@ -349,11 +352,10 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ |
|
|
|
} |
|
|
|
// update data and datatype |
|
|
|
for (uint32_t j = 0; j < oneFilterSize; j++) { |
|
|
|
float rawData = rawDatas[j + i * oneFilterSize]; |
|
|
|
float rawData = rawDatas[i + j * oneFilterSize]; |
|
|
|
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min); |
|
|
|
qDatas[j + i * oneFilterSize] = qData; |
|
|
|
qDatas[i + j * oneFilterSize] = qData; |
|
|
|
} |
|
|
|
|
|
|
|
weightPtr->set_quant_param(quantParam); |
|
|
|
} |
|
|
|
auto ret = |
|
|
|
|