| @@ -53,10 +53,22 @@ ConvDwInt8PostAlign4: | |||||
| sqrdmulh v2.4s, v2.4s, v27.4s | sqrdmulh v2.4s, v2.4s, v27.4s | ||||
| sqrdmulh v3.4s, v3.4s, v27.4s | sqrdmulh v3.4s, v3.4s, v27.4s | ||||
| sqrshl v0.4s, v0.4s, v28.4s | |||||
| sqrshl v1.4s, v1.4s, v28.4s | |||||
| sqrshl v2.4s, v2.4s, v28.4s | |||||
| sqrshl v3.4s, v3.4s, v28.4s | |||||
| and v4.16b, v0.16b, v28.16b | |||||
| sshr v4.4s, v4.4s, #31 | |||||
| sqadd v0.4s, v0.4s, v4.4s | |||||
| srshl v0.4s, v0.4s, v28.4s | |||||
| and v5.16b, v1.16b, v28.16b | |||||
| sshr v5.4s, v5.4s, #31 | |||||
| sqadd v1.4s, v1.4s, v5.4s | |||||
| srshl v1.4s, v1.4s, v28.4s | |||||
| and v6.16b, v2.16b, v28.16b | |||||
| sshr v6.4s, v6.4s, #31 | |||||
| sqadd v2.4s, v2.4s, v6.4s | |||||
| srshl v2.4s, v2.4s, v28.4s | |||||
| and v7.16b, v3.16b, v28.16b | |||||
| sshr v7.4s, v7.4s, #31 | |||||
| sqadd v3.4s, v3.4s, v7.4s | |||||
| srshl v3.4s, v3.4s, v28.4s | |||||
| AddZpDepth16: | AddZpDepth16: | ||||
| add v0.4s, v0.4s, v29.4s | add v0.4s, v0.4s, v29.4s | ||||
| @@ -109,8 +121,14 @@ ConvDwInt8PostAlign4: | |||||
| RightShiftDepth8: | RightShiftDepth8: | ||||
| sqrdmulh v0.4s, v0.4s, v27.4s | sqrdmulh v0.4s, v0.4s, v27.4s | ||||
| sqrdmulh v1.4s, v1.4s, v27.4s | sqrdmulh v1.4s, v1.4s, v27.4s | ||||
| sqrshl v0.4s, v0.4s, v28.4s | |||||
| sqrshl v1.4s, v1.4s, v28.4s | |||||
| and v4.16b, v0.16b, v28.16b | |||||
| sshr v4.4s, v4.4s, #31 | |||||
| sqadd v0.4s, v0.4s, v4.4s | |||||
| srshl v0.4s, v0.4s, v28.4s | |||||
| and v5.16b, v1.16b, v28.16b | |||||
| sshr v5.4s, v5.4s, #31 | |||||
| sqadd v1.4s, v1.4s, v5.4s | |||||
| srshl v1.4s, v1.4s, v28.4s | |||||
| AddZpDepth8: | AddZpDepth8: | ||||
| add v0.4s, v0.4s, v29.4s | add v0.4s, v0.4s, v29.4s | ||||
| @@ -140,7 +158,10 @@ ConvDwInt8PostAlign4: | |||||
| sqshl v0.4s, v0.4s, v26.4s | sqshl v0.4s, v0.4s, v26.4s | ||||
| sqrdmulh v0.4s, v0.4s, v27.4s | sqrdmulh v0.4s, v0.4s, v27.4s | ||||
| sqrshl v0.4s, v0.4s, v28.4s | |||||
| and v4.16b, v0.16b, v28.16b | |||||
| sshr v4.4s, v4.4s, #31 | |||||
| sqadd v0.4s, v0.4s, v4.4s | |||||
| srshl v0.4s, v0.4s, v28.4s | |||||
| add v0.4s, v0.4s, v29.4s | add v0.4s, v0.4s, v29.4s | ||||
| smax v0.4s, v0.4s, v30.4s | smax v0.4s, v0.4s, v30.4s | ||||
| @@ -43,8 +43,14 @@ ConvDwInt8PostAlign4PerChannel: | |||||
| sqrdmulh v0.4s, v0.4s, v4.4s | sqrdmulh v0.4s, v0.4s, v4.4s | ||||
| sqrdmulh v1.4s, v1.4s, v5.4s | sqrdmulh v1.4s, v1.4s, v5.4s | ||||
| sqrshl v0.4s, v0.4s, v6.4s | |||||
| sqrshl v1.4s, v1.4s, v7.4s | |||||
| and v16.16b, v0.16b, v6.16b | |||||
| sshr v16.4s, v16.4s, #31 | |||||
| sqadd v0.4s, v0.4s, v16.4s | |||||
| srshl v0.4s, v0.4s, v6.4s | |||||
| and v17.16b, v1.16b, v7.16b | |||||
| sshr v17.4s, v17.4s, #31 | |||||
| sqadd v1.4s, v1.4s, v17.4s | |||||
| srshl v1.4s, v1.4s, v7.4s | |||||
| add v0.4s, v0.4s, v29.4s | add v0.4s, v0.4s, v29.4s | ||||
| add v1.4s, v1.4s, v29.4s | add v1.4s, v1.4s, v29.4s | ||||
| @@ -80,7 +86,10 @@ ConvDwInt8PostAlign4PerChannel: | |||||
| sqrdmulh v0.4s, v0.4s, v4.4s | sqrdmulh v0.4s, v0.4s, v4.4s | ||||
| ld1 {v6.4s}, [x6], #16 | ld1 {v6.4s}, [x6], #16 | ||||
| sqrshl v0.4s, v0.4s, v6.4s | |||||
| and v16.16b, v0.16b, v6.16b | |||||
| sshr v16.4s, v16.4s, #31 | |||||
| sqadd v0.4s, v0.4s, v16.4s | |||||
| srshl v0.4s, v0.4s, v6.4s | |||||
| add v0.4s, v0.4s, v29.4s | add v0.4s, v0.4s, v29.4s | ||||
| smax v0.4s, v0.4s, v30.4s | smax v0.4s, v0.4s, v30.4s | ||||
| @@ -29,17 +29,24 @@ int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { | |||||
| int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, | |||||
| bool uint8_flag) { | |||||
| if (quant_values == NULL || real_values == NULL) { | if (quant_values == NULL || real_values == NULL) { | ||||
| return NNACL_PARAM_INVALID; | return NNACL_PARAM_INVALID; | ||||
| } | } | ||||
| if (uint8_flag) { | |||||
| zp += 128; | |||||
| } | |||||
| const float inverse_scale = 1.0f / scale; | const float inverse_scale = 1.0f / scale; | ||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| if (isinf(real_values[i])) { | if (isinf(real_values[i])) { | ||||
| quant_values[i] = 127; | quant_values[i] = 127; | ||||
| } else { | } else { | ||||
| int temp = round(real_values[i] * inverse_scale + zp); | int temp = round(real_values[i] * inverse_scale + zp); | ||||
| if (uint8_flag) { | |||||
| temp -= 128; | |||||
| } | |||||
| temp = temp < 127 ? temp : 127; | temp = temp < 127 ? temp : 127; | ||||
| temp = temp > -128 ? temp : -128; | temp = temp > -128 ? temp : -128; | ||||
| quant_values[i] = (int8_t)temp; | quant_values[i] = (int8_t)temp; | ||||
| @@ -29,7 +29,8 @@ typedef struct QuantDTypeCastParameter { | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); | int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); | ||||
| int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size); | |||||
| int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, | |||||
| bool uint8_flag); | |||||
| int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); | int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); | ||||
| int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); | int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); | ||||
| int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); | int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); | ||||
| @@ -80,6 +80,12 @@ typedef struct OpParameter { | |||||
| typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; | typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; | ||||
| typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode; | typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode; | ||||
| typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; | |||||
| typedef enum CalFixedMultiplierMode { | |||||
| Method_No, | |||||
| Method_SinglePrecision, | |||||
| Method_DoublePrecision | |||||
| } CalFixedMultiplierMode; | |||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| #define MS_FLOAT32X4 float32x4_t | #define MS_FLOAT32X4 float32x4_t | ||||
| @@ -42,7 +42,7 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) { | |||||
| } | } | ||||
| // division by a 2^exponent with rounding | // division by a 2^exponent with rounding | ||||
| // or arithmetic right shift with rouding | |||||
| // or arithmetic right shift with rounding | |||||
| int RoundingDivideByPOT(int x, int exponent) { | int RoundingDivideByPOT(int x, int exponent) { | ||||
| const int mask = (1ll << exponent) - 1; | const int mask = (1ll << exponent) - 1; | ||||
| const int remainder = x & mask; | const int remainder = x & mask; | ||||
| @@ -50,10 +50,23 @@ int RoundingDivideByPOT(int x, int exponent) { | |||||
| return (x >> exponent) + (remainder > threshold ? 1 : 0); | return (x >> exponent) + (remainder > threshold ? 1 : 0); | ||||
| } | } | ||||
| int UpwardRounding(int x, int exponent) { | |||||
| const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0; | |||||
| if (x > INT32_MAX - rounding_offset) { | |||||
| return 1 << (31 - exponent); | |||||
| } | |||||
| return (x + rounding_offset) >> exponent; | |||||
| } | |||||
| int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { | int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { | ||||
| return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); | return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); | ||||
| } | } | ||||
| int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, | |||||
| int32_t right_shift) { | |||||
| return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); | |||||
| } | |||||
| int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) { | int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) { | ||||
| return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift); | return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift); | ||||
| } | } | ||||
| @@ -40,8 +40,13 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b); | |||||
| // or arithmetic right shift with rounding | // or arithmetic right shift with rounding | ||||
| int RoundingDivideByPOT(int x, int exponent); | int RoundingDivideByPOT(int x, int exponent); | ||||
| int UpwardRounding(int x, int exponent); | |||||
| int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); | int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); | ||||
| int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, | |||||
| int32_t right_shift); | |||||
| int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift); | int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift); | ||||
| int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent); | int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent); | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/quantization/quantize.h" | #include "nnacl/quantization/quantize.h" | ||||
| #include <stdio.h> | |||||
| const uint64_t dSignMask = 1ull << 63; | const uint64_t dSignMask = 1ull << 63; | ||||
| const uint64_t dExponentMask = 0x7ffull << 52; | const uint64_t dExponentMask = 0x7ffull << 52; | ||||
| @@ -35,8 +36,8 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantiz | |||||
| *right_shift = -shift; | *right_shift = -shift; | ||||
| } | } | ||||
| void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, | |||||
| int *right_shift) { | |||||
| void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, | |||||
| int *right_shift) { | |||||
| int shift = 0; | int shift = 0; | ||||
| QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); | QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); | ||||
| shift = -shift; | shift = -shift; | ||||
| @@ -49,6 +50,29 @@ void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multipl | |||||
| } | } | ||||
| } | } | ||||
| void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, | |||||
| int *right_shift) { | |||||
| int shift = 0; | |||||
| const uint32_t scale_bits = (uint32_t)(double_multiplier); | |||||
| /* multiplier must land in [0x40000000, 0x7FFFFF80]; NOTE(review): the double->uint32 value cast above looks suspect — an IEEE-754 bit reinterpretation of the scale (fp32_to_bits) is presumably intended; confirm against caller */ | |||||
| *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); | |||||
| if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) { | |||||
| printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n", | |||||
| quantized_multiplier[0]); | |||||
| return; | |||||
| } | |||||
| /* shift is in [0, 31] range */ | |||||
| shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23); | |||||
| shift = -shift; | |||||
| if (shift < 0) { | |||||
| *left_shift = 0; | |||||
| *right_shift = shift; | |||||
| } else { | |||||
| *left_shift = shift; | |||||
| *right_shift = 0; | |||||
| } | |||||
| } | |||||
| uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } | uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } | ||||
| int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } | int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } | ||||
| @@ -34,6 +34,8 @@ typedef struct QuantArg { | |||||
| } QuantArg; | } QuantArg; | ||||
| typedef struct ConvQuantArg { | typedef struct ConvQuantArg { | ||||
| RoundingMode round_mode_; | |||||
| CalFixedMultiplierMode quant_multiplier_mode_; | |||||
| QuantArg *input_quant_args_; | QuantArg *input_quant_args_; | ||||
| QuantArg *filter_quant_args_; | QuantArg *filter_quant_args_; | ||||
| QuantArg *output_quant_args_; | QuantArg *output_quant_args_; | ||||
| @@ -46,7 +48,6 @@ typedef struct ConvQuantArg { | |||||
| size_t input_arg_num_; | size_t input_arg_num_; | ||||
| size_t filter_arg_num_; | size_t filter_arg_num_; | ||||
| size_t output_arg_num_; | size_t output_arg_num_; | ||||
| uint8_t asymmetric_; | |||||
| uint8_t per_channel_; | uint8_t per_channel_; | ||||
| } ConvQuantArg; | } ConvQuantArg; | ||||
| @@ -282,7 +283,11 @@ void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, | |||||
| void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift); | void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift); | ||||
| void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, int *right_shift); | |||||
| void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, | |||||
| int *right_shift); | |||||
| void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, | |||||
| int *right_shift); | |||||
| uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); | uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); | ||||
| @@ -40,6 +40,8 @@ table QuantParam { | |||||
| varCorr: float = 1; | varCorr: float = 1; | ||||
| meanCorr: float = 0; | meanCorr: float = 0; | ||||
| dstDtype: int = 32; | dstDtype: int = 32; | ||||
| roundType: int = 1; | |||||
| multiplier: int = -1; // method used to calculate the fixed-point multiplier | |||||
| } | } | ||||
| table Tensor { | table Tensor { | ||||
| @@ -69,6 +69,9 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit | |||||
| quant_arg.var_corr = quant_params->Get(j)->varCorr(); | quant_arg.var_corr = quant_params->Get(j)->varCorr(); | ||||
| quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | ||||
| quant_arg.inited = quant_params->Get(j)->inited(); | quant_arg.inited = quant_params->Get(j)->inited(); | ||||
| quant_arg.roundType = quant_params->Get(j)->roundType(); | |||||
| quant_arg.multiplier = quant_params->Get(j)->multiplier(); | |||||
| quant_arg.dstDtype = quant_params->Get(j)->dstDtype(); | |||||
| dst_tensor->AddQuantParam(quant_arg); | dst_tensor->AddQuantParam(quant_arg); | ||||
| } | } | ||||
| } | } | ||||
| @@ -261,12 +261,43 @@ int ConvolutionBaseCPUKernel::SetQuantMultiplier() { | |||||
| static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_); | static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_); | ||||
| double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_); | double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_); | ||||
| conv_quant_arg_->real_multiplier_[i] = real_multiplier; | conv_quant_arg_->real_multiplier_[i] = real_multiplier; | ||||
| QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i], | |||||
| &conv_quant_arg_->right_shift_[i]); | |||||
| if (conv_quant_arg_->quant_multiplier_mode_ == Method_SinglePrecision) { | |||||
| QuantizeRoundParameterWithSinglePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], | |||||
| &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]); | |||||
| } else if (conv_quant_arg_->quant_multiplier_mode_ == Method_DoublePrecision) { | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], | |||||
| &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void ConvolutionBaseCPUKernel::SetRoundingAndMultipilerMode() { | |||||
| auto input_quant_arg = in_tensors_.at(kInputIndex)->quant_params().front(); | |||||
| int round_type = input_quant_arg.roundType; | |||||
| switch (round_type) { | |||||
| case 1: | |||||
| conv_quant_arg_->round_mode_ = Rounding_Away_from_zero; | |||||
| break; | |||||
| case 2: | |||||
| conv_quant_arg_->round_mode_ = Rounding_Up; | |||||
| break; | |||||
| default: | |||||
| conv_quant_arg_->round_mode_ = Rounding_No; | |||||
| } | |||||
| int cal_multiplier_type = input_quant_arg.multiplier; | |||||
| switch (cal_multiplier_type) { | |||||
| case 0: | |||||
| conv_quant_arg_->quant_multiplier_mode_ = Method_SinglePrecision; | |||||
| break; | |||||
| case 1: | |||||
| conv_quant_arg_->quant_multiplier_mode_ = Method_DoublePrecision; | |||||
| break; | |||||
| default: | |||||
| conv_quant_arg_->quant_multiplier_mode_ = Method_No; | |||||
| } | |||||
| } | |||||
| int ConvolutionBaseCPUKernel::SetQuantParam() { | int ConvolutionBaseCPUKernel::SetQuantParam() { | ||||
| auto ret = MallocQuantParam(); | auto ret = MallocQuantParam(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -288,13 +319,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { | |||||
| MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed."; | MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = SetIfPerChannel(); | ret = SetIfPerChannel(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set if per tensor channel failed."; | MS_LOG(ERROR) << "Set if per tensor channel failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| SetRoundingAndMultipilerMode(); | |||||
| ret = SetQuantMultiplier(); | ret = SetQuantMultiplier(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set Quant Multiplier Failed."; | MS_LOG(ERROR) << "Set Quant Multiplier Failed."; | ||||
| @@ -53,6 +53,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel { | |||||
| int SetFilterTensorQuantParam(); | int SetFilterTensorQuantParam(); | ||||
| int SetOutputTensorQuantParam(); | int SetOutputTensorQuantParam(); | ||||
| int SetQuantMultiplier(); | int SetQuantMultiplier(); | ||||
| void SetRoundingAndMultipilerMode(); | |||||
| int CheckResizeValid(); | int CheckResizeValid(); | ||||
| void FreeQuantParam(); | void FreeQuantParam(); | ||||
| @@ -120,8 +120,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { | |||||
| ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, | ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, | ||||
| quant_arg.zeroPoint, num_unit_thread); | quant_arg.zeroPoint, num_unit_thread); | ||||
| } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) { | } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) { | ||||
| bool from_uint8_src = false; | |||||
| if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) { | |||||
| from_uint8_src = true; | |||||
| } | |||||
| ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, | ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, | ||||
| quant_arg.zeroPoint, num_unit_thread); | |||||
| quant_arg.zeroPoint, num_unit_thread, from_uint8_src); | |||||
| } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { | } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { | ||||
| ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); | ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); | ||||
| } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { | } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { | ||||
| @@ -138,8 +142,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { | |||||
| input_quant_arg.scale, input_quant_arg.zeroPoint); | input_quant_arg.scale, input_quant_arg.zeroPoint); | ||||
| if (ret) { | if (ret) { | ||||
| auto output_quant_arg = out_tensors_.front()->quant_params().front(); | auto output_quant_arg = out_tensors_.front()->quant_params().front(); | ||||
| bool from_uint8_src = false; | |||||
| if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) { | |||||
| from_uint8_src = true; | |||||
| } | |||||
| ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale, | ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale, | ||||
| output_quant_arg.zeroPoint, num_unit_thread); | |||||
| output_quant_arg.zeroPoint, num_unit_thread, from_uint8_src); | |||||
| } | } | ||||
| } | } | ||||
| @@ -254,8 +254,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() { | |||||
| const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]); | const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]); | ||||
| double real_multiplier = in_scale / static_cast<double>(output_scale_[i]); | double real_multiplier = in_scale / static_cast<double>(output_scale_[i]); | ||||
| conv_quant_arg_->real_multiplier_[i] = real_multiplier; | conv_quant_arg_->real_multiplier_[i] = real_multiplier; | ||||
| QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i], | |||||
| &conv_quant_arg_->right_shift_[i]); | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], | |||||
| &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]); | |||||
| } | } | ||||
| // now only consider per tensor for output | // now only consider per tensor for output | ||||
| @@ -132,8 +132,8 @@ int FullconnectionInt8CPUKernel::Init() { | |||||
| for (int i = 0; i < weight_quant_num; ++i) { | for (int i = 0; i < weight_quant_num; ++i) { | ||||
| const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]); | const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]); | ||||
| double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_); | double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_); | ||||
| QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], | |||||
| &quant_.right_shift_[i]); | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], | |||||
| &quant_.right_shift_[i]); | |||||
| } | } | ||||
| CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, | CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, | ||||
| @@ -138,8 +138,8 @@ int MatmulInt8CPUKernel::ReSize() { | |||||
| } | } | ||||
| } | } | ||||
| double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; | double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; | ||||
| QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift, | |||||
| &quant_params_.right_shift); | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift, | |||||
| &quant_params_.right_shift); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -39,7 +39,8 @@ int ReluXInt8CPUKernel::Init() { | |||||
| quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint; | quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint; | ||||
| const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_; | const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_; | ||||
| QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_); | |||||
| QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, | |||||
| &quant_arg_.right_shift_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -86,8 +86,8 @@ int ResizeInt8CPUKernel::Init() { | |||||
| quant_out_->zp_ = output->quant_params().front().zeroPoint; | quant_out_->zp_ = output->quant_params().front().zeroPoint; | ||||
| quant_out_->scale_ = output->quant_params().front().scale; | quant_out_->scale_ = output->quant_params().front().scale; | ||||
| QuantizeRoundParameter(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, &multiplier_->left_shift_, | |||||
| &multiplier_->right_shift_); | |||||
| QuantizeRoundParameterWithDoublePrecision(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, | |||||
| &multiplier_->left_shift_, &multiplier_->right_shift_); | |||||
| if (!InferShapeDone()) { | if (!InferShapeDone()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -38,6 +38,9 @@ struct QuantArg { | |||||
| bool inited; | bool inited; | ||||
| std::vector<float> clusters{}; | std::vector<float> clusters{}; | ||||
| int bitNum; | int bitNum; | ||||
| int roundType; | |||||
| int multiplier; | |||||
| int dstDtype; | |||||
| }; | }; | ||||
| class Tensor : public mindspore::tensor::MSTensor { | class Tensor : public mindspore::tensor::MSTensor { | ||||
| @@ -118,7 +118,7 @@ TEST_F(TestMatmulInt8, simple) { | |||||
| int a_sums[ROW4] = {0}; | int a_sums[ROW4] = {0}; | ||||
| int bias[COL4] = {0}; | int bias[COL4] = {0}; | ||||
| int multiplier, ls, rs; | int multiplier, ls, rs; | ||||
| QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs); | |||||
| QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs); | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, | MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, | ||||
| &rs, ROW, COL, COL, false); | &rs, ROW, COL, COL, false); | ||||
| @@ -121,6 +121,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | std::make_unique<schema::QuantParamT>(input_quant_param); | ||||
| MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | ||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | << " zp: " << input_quant_param_ptr->zeroPoint; | ||||
| input_quant_param_ptr->dstDtype = tensor_input->dataType; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -151,6 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||||
| std::make_unique<schema::QuantParamT>(channel_quant_param); | std::make_unique<schema::QuantParamT>(channel_quant_param); | ||||
| MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale | MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale | ||||
| << " zp: " << output_quant_param_ptr->zeroPoint; | << " zp: " << output_quant_param_ptr->zeroPoint; | ||||
| output_quant_param_ptr->dstDtype = output_tensor->dataType; | |||||
| output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr)); | output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -258,6 +260,9 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu | |||||
| auto subgraph_name = func_graph->get_attr("graph_name"); | auto subgraph_name = func_graph->get_attr("graph_name"); | ||||
| MS_ASSERT(subgraph_name != nullptr); | MS_ASSERT(subgraph_name != nullptr); | ||||
| sub_graphT->name = GetValue<std::string>(subgraph_name); | sub_graphT->name = GetValue<std::string>(subgraph_name); | ||||
| auto fmk = func_graph->get_attr("fmk"); | |||||
| MS_ASSERT(fmk != nullptr); | |||||
| meta_graphT->fmkType = GetValue<int>(fmk); | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | auto cnodes = func_graph->GetOrderedCnodes(); | ||||
| for (const auto &cnode : cnodes) { | for (const auto &cnode : cnodes) { | ||||
| @@ -448,6 +448,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||||
| toAddTensor->dataType = prim->dstT; | toAddTensor->dataType = prim->dstT; | ||||
| if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | ||||
| preTensor->quantParams.front()->zeroPoint += 128; | preTensor->quantParams.front()->zeroPoint += 128; | ||||
| } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) { | |||||
| toAddTensor->quantParams.front()->zeroPoint += 128; | |||||
| } | } | ||||
| } | } | ||||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | graphT->allTensors.emplace_back(std::move(toAddTensor)); | ||||
| @@ -491,6 +493,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||||
| toAddTensor->dataType = prim->dstT; | toAddTensor->dataType = prim->dstT; | ||||
| if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | ||||
| preTensor->quantParams.front()->zeroPoint += 128; | preTensor->quantParams.front()->zeroPoint += 128; | ||||
| } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) { | |||||
| toAddTensor->quantParams.front()->zeroPoint += 128; | |||||
| } | } | ||||
| } | } | ||||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | graphT->allTensors.emplace_back(std::move(toAddTensor)); | ||||
| @@ -552,8 +556,10 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||||
| MS_ASSERT(prim != nullptr); | MS_ASSERT(prim != nullptr); | ||||
| postTensor->dataType = prim->srcT; | postTensor->dataType = prim->srcT; | ||||
| toAddTensor->dataType = prim->dstT; | toAddTensor->dataType = prim->dstT; | ||||
| if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { | |||||
| if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) { | |||||
| toAddTensor->quantParams.front()->zeroPoint += 128; | toAddTensor->quantParams.front()->zeroPoint += 128; | ||||
| } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | |||||
| postTensor->quantParams.front()->zeroPoint += 128; | |||||
| } | } | ||||
| } | } | ||||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | graphT->allTensors.emplace_back(std::move(toAddTensor)); | ||||
| @@ -624,6 +630,8 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||||
| toAddTensor->dataType = prim->dstT; | toAddTensor->dataType = prim->dstT; | ||||
| if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { | if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { | ||||
| toAddTensor->quantParams.front()->zeroPoint += 128; | toAddTensor->quantParams.front()->zeroPoint += 128; | ||||
| } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { | |||||
| postTensor->quantParams.front()->zeroPoint += 128; | |||||
| } | } | ||||
| } | } | ||||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | graphT->allTensors.emplace_back(std::move(toAddTensor)); | ||||
| @@ -38,6 +38,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem | |||||
| dstQuantParam->max = srcQuantParam->max; | dstQuantParam->max = srcQuantParam->max; | ||||
| dstQuantParam->narrowRange = srcQuantParam->narrowRange; | dstQuantParam->narrowRange = srcQuantParam->narrowRange; | ||||
| dstQuantParam->numBits = srcQuantParam->numBits; | dstQuantParam->numBits = srcQuantParam->numBits; | ||||
| dstQuantParam->dstDtype = srcQuantParam->dstDtype; | |||||
| dstQuantParam->multiplier = srcQuantParam->multiplier; | |||||
| return dstQuantParam; | return dstQuantParam; | ||||
| } | } | ||||
| @@ -71,6 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| graph->set_attr("graph_name", MakeValue("main_graph")); | graph->set_attr("graph_name", MakeValue("main_graph")); | ||||
| graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS))); | |||||
| } else { | } else { | ||||
| MS_ASSERT(nullptr != modelParser); | MS_ASSERT(nullptr != modelParser); | ||||
| const std::string modelFile = flag->modelFile; | const std::string modelFile = flag->modelFile; | ||||
| @@ -158,7 +158,7 @@ int Flags::Init(int argc, const char **argv) { | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| if (this->trainModel == true) { | |||||
| if (this->trainModel) { | |||||
| if (this->fmk != FmkType_MS) { | if (this->fmk != FmkType_MS) { | ||||
| std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format"; | std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format"; | ||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| @@ -30,7 +30,14 @@ using mindspore::schema::QuantType_PostTraining; | |||||
| using mindspore::schema::QuantType_QUANT_NONE; | using mindspore::schema::QuantType_QUANT_NONE; | ||||
| using mindspore::schema::QuantType_WeightQuant; | using mindspore::schema::QuantType_WeightQuant; | ||||
| namespace converter { | namespace converter { | ||||
| enum FmkType { FmkType_TF = 0, FmkType_CAFFE = 1, FmkType_ONNX = 2, FmkType_MS = 3, FmkType_TFLITE = 4 }; | |||||
| enum FmkType { | |||||
| FmkType_TF = 0, | |||||
| FmkType_CAFFE = 1, | |||||
| FmkType_ONNX = 2, | |||||
| FmkType_MS = 3, | |||||
| FmkType_TFLITE = 4, | |||||
| FmkType_ONNX_LOW_VERSION = 5 | |||||
| }; | |||||
| class Flags : public virtual mindspore::lite::FlagParser { | class Flags : public virtual mindspore::lite::FlagParser { | ||||
| public: | public: | ||||
| @@ -161,7 +161,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if (postTensor->dataType != TypeId::kNumberTypeInt8) { | if (postTensor->dataType != TypeId::kNumberTypeInt8) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status); | |||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -25,7 +25,6 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) { | |||||
| quant_param->min = 0; | quant_param->min = 0; | ||||
| quant_param->max = 0; | quant_param->max = 0; | ||||
| quant_param->narrowRange = true; | quant_param->narrowRange = true; | ||||
| quant_param->dstDtype = TypeId::kNumberTypeFloat32; | |||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -44,7 +44,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| int index = -1; | |||||
| unsigned int index = -1; | |||||
| for (auto &tensor : graph->allTensors) { | for (auto &tensor : graph->allTensors) { | ||||
| index++; | index++; | ||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | ||||
| @@ -59,7 +59,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| auto &quantParam = tensor->quantParams.front(); | auto &quantParam = tensor->quantParams.front(); | ||||
| size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get())); | size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get())); | ||||
| void *oriWeightData = tensor->data.data(); | void *oriWeightData = tensor->data.data(); | ||||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||||
| if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 || | |||||
| quantParam->dstDtype == TypeId::kNumberTypeFloat) { | |||||
| std::vector<int8_t> qDatas(wShapeSize); | std::vector<int8_t> qDatas(wShapeSize); | ||||
| auto weightQauntParam = GetTensorQuantParam(tensor); | auto weightQauntParam = GetTensorQuantParam(tensor); | ||||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | if (tensor->dataType == TypeId::kNumberTypeFloat || | ||||
| @@ -71,7 +72,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | for (size_t j = 0; j < wShapeSize; j++) { | ||||
| qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | ||||
| } | } | ||||
| } else { // tflite awareing quant | |||||
| } else { // convert uint8 to int8 | |||||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | auto *weightData = static_cast<uint8_t *>(oriWeightData); | ||||
| for (size_t j = 0; j < wShapeSize; j++) { | for (size_t j = 0; j < wShapeSize; j++) { | ||||
| qDatas[j] = (int32_t)weightData[j] - 128; | qDatas[j] = (int32_t)weightData[j] - 128; | ||||
| @@ -55,6 +55,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | ||||
| func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE))); | |||||
| return func_graph_ptr_; | return func_graph_ptr_; | ||||
| } | } | ||||
| @@ -40,6 +40,7 @@ std::set<std::string> SPECIAL_NODE = {"Gemm"}; | |||||
| FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | ||||
| const QuantType &quant_type) { | const QuantType &quant_type) { | ||||
| NoSupportOp::GetInstance()->SetFmkType("ONNX"); | NoSupportOp::GetInstance()->SetFmkType("ONNX"); | ||||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||||
| auto status = InitOriginModel(model_file); | auto status = InitOriginModel(model_file); | ||||
| if (RET_OK != status) { | if (RET_OK != status) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| @@ -47,7 +48,6 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||||
| status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); | status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); | ||||
| if (RET_OK != status) { | if (RET_OK != status) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| @@ -77,6 +77,11 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { | |||||
| } | } | ||||
| OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); | OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); | ||||
| onnx_root_graph_ = onnx_model_.graph(); | onnx_root_graph_ = onnx_model_.graph(); | ||||
| if (OnnxNodeParser::opset_version() > 15) { | |||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||||
| } else { | |||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION))); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph, | STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph, | ||||
| @@ -614,6 +619,9 @@ STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_na | |||||
| std::vector<QuantParamT> *quant_params) { | std::vector<QuantParamT> *quant_params) { | ||||
| quant_params->clear(); | quant_params->clear(); | ||||
| auto quant_param = std::make_unique<QuantParamT>(); | auto quant_param = std::make_unique<QuantParamT>(); | ||||
| if (OnnxNodeParser::opset_version() <= 15) { | |||||
| quant_param->multiplier = 0; | |||||
| } | |||||
| std::string quant_tensor_name = "scale_" + tensor_name; | std::string quant_tensor_name = "scale_" + tensor_name; | ||||
| auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true); | auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -366,6 +366,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | ||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF))); | |||||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | for (int i = 0; i < tf_root_graph_->node_size(); i++) { | ||||
| auto &node_def = tf_root_graph_->node(i); | auto &node_def = tf_root_graph_->node(i); | ||||
| @@ -441,6 +442,7 @@ STATUS TFModelParser::ConvertSubgraph() { | |||||
| FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | ||||
| sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); | sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); | ||||
| sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF))); | |||||
| std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | ||||
| // convert sub graph inputs | // convert sub graph inputs | ||||
| std::vector<ParameterPtr> sub_graph_inputs; | std::vector<ParameterPtr> sub_graph_inputs; | ||||
| @@ -55,6 +55,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| func_graph_ = std::make_shared<FuncGraph>(); | func_graph_ = std::make_shared<FuncGraph>(); | ||||
| func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE))); | |||||
| auto status = ConvertGraphInputs(); | auto status = ConvertGraphInputs(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -183,7 +184,7 @@ STATUS TfliteModelParser::ConvertOps() { | |||||
| } | } | ||||
| STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor, | STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor, | ||||
| std::vector<QuantParamT> *quant_params) { | |||||
| std::vector<QuantParamT> *quant_params, int round_type) { | |||||
| if (tflite_tensor == nullptr) { | if (tflite_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed."; | MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed."; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -221,6 +222,8 @@ STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tens | |||||
| quant_param->max = tflite_tensor->quantization->max[i]; | quant_param->max = tflite_tensor->quantization->max[i]; | ||||
| } | } | ||||
| quant_param->inited = true; | quant_param->inited = true; | ||||
| quant_param->roundType = round_type; | |||||
| quant_param->multiplier = 1; | |||||
| quant_params->emplace_back(*std::move(quant_param)); | quant_params->emplace_back(*std::move(quant_param)); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -236,6 +239,11 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite | |||||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| int round_type = 1; | |||||
| if (primitive_c->primitiveT()->value.type == PrimitiveType_Conv2D) { | |||||
| round_type = 2; | |||||
| } | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | ||||
| for (auto input_idx : op->inputs) { | for (auto input_idx : op->inputs) { | ||||
| if (input_idx < 0) { | if (input_idx < 0) { | ||||
| @@ -243,7 +251,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite | |||||
| } | } | ||||
| const auto &input_tensor = tflite_subgraph->tensors[input_idx]; | const auto &input_tensor = tflite_subgraph->tensors[input_idx]; | ||||
| std::vector<schema::QuantParamT> quant_params; | std::vector<schema::QuantParamT> quant_params; | ||||
| auto status = SetTensorQuantParam(input_tensor.get(), &quant_params); | |||||
| auto status = SetTensorQuantParam(input_tensor.get(), &quant_params, round_type); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "set input tensor quant param failed."; | MS_LOG(ERROR) << "set input tensor quant param failed."; | ||||
| return status; | return status; | ||||
| @@ -256,7 +264,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite | |||||
| } | } | ||||
| const auto &output_tensor = tflite_subgraph->tensors.at(output_idx); | const auto &output_tensor = tflite_subgraph->tensors.at(output_idx); | ||||
| std::vector<schema::QuantParamT> quant_params; | std::vector<schema::QuantParamT> quant_params; | ||||
| auto status = SetTensorQuantParam(output_tensor.get(), &quant_params); | |||||
| auto status = SetTensorQuantParam(output_tensor.get(), &quant_params, round_type); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "set output tensor quant param failed."; | MS_LOG(ERROR) << "set output tensor quant param failed."; | ||||
| return status; | return status; | ||||
| @@ -48,7 +48,8 @@ class TfliteModelParser : public ModelParser { | |||||
| STATUS ConvertOps(); | STATUS ConvertOps(); | ||||
| STATUS ConvertGraphInputs(); | STATUS ConvertGraphInputs(); | ||||
| STATUS ConvertGraphOutputs(); | STATUS ConvertGraphOutputs(); | ||||
| static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params); | |||||
| static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params, | |||||
| int round_type = 1); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // LITE_TFLITE_MODEL_PARSER_H | #endif // LITE_TFLITE_MODEL_PARSER_H | ||||
| @@ -595,6 +595,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru | |||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| quant_param.inited = true; | quant_param.inited = true; | ||||
| quant_param.roundType = 1; | |||||
| quant_param.multiplier = 1; | |||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | std::vector<schema::QuantParamT> quant_params = {quant_param}; | ||||
| lite_primitive->AddInputQuantParam(quant_params); | lite_primitive->AddInputQuantParam(quant_params); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -612,6 +614,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| quant_param.inited = true; | quant_param.inited = true; | ||||
| quant_param.roundType = 1; | |||||
| quant_param.multiplier = 1; | |||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | std::vector<schema::QuantParamT> quant_params = {quant_param}; | ||||
| lite_primitive->AddOutputQuantParam(quant_params); | lite_primitive->AddOutputQuantParam(quant_params); | ||||
| return RET_OK; | return RET_OK; | ||||