Browse Source

fix quantized rounding

tags/v1.2.0-rc1
fuzhiye 4 years ago
parent
commit
6d86efc1d8
36 changed files with 232 additions and 46 deletions
  1. +28
    -7
      mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S
  2. +12
    -3
      mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
  3. +8
    -1
      mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
  4. +2
    -1
      mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h
  5. +6
    -0
      mindspore/lite/nnacl/op_base.h
  6. +14
    -1
      mindspore/lite/nnacl/quantization/fixed_point.c
  7. +5
    -0
      mindspore/lite/nnacl/quantization/fixed_point.h
  8. +26
    -2
      mindspore/lite/nnacl/quantization/quantize.c
  9. +7
    -2
      mindspore/lite/nnacl/quantization/quantize.h
  10. +2
    -0
      mindspore/lite/schema/model.fbs
  11. +3
    -0
      mindspore/lite/src/lite_session.cc
  12. +34
    -4
      mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
  13. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
  14. +10
    -2
      mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
  15. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc
  16. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
  17. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
  18. +2
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
  19. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc
  20. +3
    -0
      mindspore/lite/src/tensor.h
  21. +1
    -1
      mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
  22. +5
    -0
      mindspore/lite/tools/anf_exporter/anf_exporter.cc
  23. +9
    -1
      mindspore/lite/tools/common/graph_util.cc
  24. +2
    -0
      mindspore/lite/tools/common/tensor_util.cc
  25. +1
    -0
      mindspore/lite/tools/converter/converter.cc
  26. +1
    -1
      mindspore/lite/tools/converter/converter_flags.cc
  27. +8
    -1
      mindspore/lite/tools/converter/converter_flags.h
  28. +1
    -1
      mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
  29. +0
    -1
      mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc
  30. +4
    -3
      mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
  31. +1
    -0
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
  32. +9
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  33. +2
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  34. +11
    -3
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
  35. +2
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
  36. +4
    -0
      mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc

+ 28
- 7
mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S View File

@@ -53,10 +53,22 @@ ConvDwInt8PostAlign4:
sqrdmulh v2.4s, v2.4s, v27.4s sqrdmulh v2.4s, v2.4s, v27.4s
sqrdmulh v3.4s, v3.4s, v27.4s sqrdmulh v3.4s, v3.4s, v27.4s


sqrshl v0.4s, v0.4s, v28.4s
sqrshl v1.4s, v1.4s, v28.4s
sqrshl v2.4s, v2.4s, v28.4s
sqrshl v3.4s, v3.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s
and v5.16b, v1.16b, v28.16b
sshr v5.4s, v5.4s, #31
sqadd v1.4s, v1.4s, v5.4s
srshl v1.4s, v1.4s, v28.4s
and v6.16b, v2.16b, v28.16b
sshr v6.4s, v6.4s, #31
sqadd v2.4s, v2.4s, v6.4s
srshl v2.4s, v2.4s, v28.4s
and v7.16b, v3.16b, v28.16b
sshr v7.4s, v7.4s, #31
sqadd v3.4s, v3.4s, v7.4s
srshl v3.4s, v3.4s, v28.4s


AddZpDepth16: AddZpDepth16:
add v0.4s, v0.4s, v29.4s add v0.4s, v0.4s, v29.4s
@@ -109,8 +121,14 @@ ConvDwInt8PostAlign4:
RightShiftDepth8: RightShiftDepth8:
sqrdmulh v0.4s, v0.4s, v27.4s sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s sqrdmulh v1.4s, v1.4s, v27.4s
sqrshl v0.4s, v0.4s, v28.4s
sqrshl v1.4s, v1.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s
and v5.16b, v1.16b, v28.16b
sshr v5.4s, v5.4s, #31
sqadd v1.4s, v1.4s, v5.4s
srshl v1.4s, v1.4s, v28.4s


AddZpDepth8: AddZpDepth8:
add v0.4s, v0.4s, v29.4s add v0.4s, v0.4s, v29.4s
@@ -140,7 +158,10 @@ ConvDwInt8PostAlign4:


sqshl v0.4s, v0.4s, v26.4s sqshl v0.4s, v0.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s sqrdmulh v0.4s, v0.4s, v27.4s
sqrshl v0.4s, v0.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s


add v0.4s, v0.4s, v29.4s add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s smax v0.4s, v0.4s, v30.4s


+ 12
- 3
mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S View File

@@ -43,8 +43,14 @@ ConvDwInt8PostAlign4PerChannel:


sqrdmulh v0.4s, v0.4s, v4.4s sqrdmulh v0.4s, v0.4s, v4.4s
sqrdmulh v1.4s, v1.4s, v5.4s sqrdmulh v1.4s, v1.4s, v5.4s
sqrshl v0.4s, v0.4s, v6.4s
sqrshl v1.4s, v1.4s, v7.4s
and v16.16b, v0.16b, v6.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s
and v17.16b, v1.16b, v7.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v7.4s


add v0.4s, v0.4s, v29.4s add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s add v1.4s, v1.4s, v29.4s
@@ -80,7 +86,10 @@ ConvDwInt8PostAlign4PerChannel:
sqrdmulh v0.4s, v0.4s, v4.4s sqrdmulh v0.4s, v0.4s, v4.4s


ld1 {v6.4s}, [x6], #16 ld1 {v6.4s}, [x6], #16
sqrshl v0.4s, v0.4s, v6.4s
and v16.16b, v0.16b, v6.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s


add v0.4s, v0.4s, v29.4s add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s smax v0.4s, v0.4s, v30.4s


+ 8
- 1
mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c View File

@@ -29,17 +29,24 @@ int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float
return NNACL_OK; return NNACL_OK;
} }


int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
bool uint8_flag) {
if (quant_values == NULL || real_values == NULL) { if (quant_values == NULL || real_values == NULL) {
return NNACL_PARAM_INVALID; return NNACL_PARAM_INVALID;
} }


if (uint8_flag) {
zp += 128;
}
const float inverse_scale = 1.0f / scale; const float inverse_scale = 1.0f / scale;
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
if (isinf(real_values[i])) { if (isinf(real_values[i])) {
quant_values[i] = 127; quant_values[i] = 127;
} else { } else {
int temp = round(real_values[i] * inverse_scale + zp); int temp = round(real_values[i] * inverse_scale + zp);
if (uint8_flag) {
temp -= 128;
}
temp = temp < 127 ? temp : 127; temp = temp < 127 ? temp : 127;
temp = temp > -128 ? temp : -128; temp = temp > -128 ? temp : -128;
quant_values[i] = (int8_t)temp; quant_values[i] = (int8_t)temp;


+ 2
- 1
mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h View File

@@ -29,7 +29,8 @@ typedef struct QuantDTypeCastParameter {
extern "C" { extern "C" {
#endif #endif
int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
bool uint8_flag);
int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size);


+ 6
- 0
mindspore/lite/nnacl/op_base.h View File

@@ -80,6 +80,12 @@ typedef struct OpParameter {


typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode; typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode;
typedef enum CalFixedMultiplierMode {
Method_No,
Method_SinglePrecision,
Method_DoublePrecision
} CalFixedMultiplierMode;


#ifdef ENABLE_ARM #ifdef ENABLE_ARM
#define MS_FLOAT32X4 float32x4_t #define MS_FLOAT32X4 float32x4_t


+ 14
- 1
mindspore/lite/nnacl/quantization/fixed_point.c View File

@@ -42,7 +42,7 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
} }


// division by a 2^exponent with rounding // division by a 2^exponent with rounding
// or arithmetic right shift with rouding
// or arithmetic right shift with rounding
int RoundingDivideByPOT(int x, int exponent) { int RoundingDivideByPOT(int x, int exponent) {
const int mask = (1ll << exponent) - 1; const int mask = (1ll << exponent) - 1;
const int remainder = x & mask; const int remainder = x & mask;
@@ -50,10 +50,23 @@ int RoundingDivideByPOT(int x, int exponent) {
return (x >> exponent) + (remainder > threshold ? 1 : 0); return (x >> exponent) + (remainder > threshold ? 1 : 0);
} }


int UpwardRounding(int x, int exponent) {
const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0;
if (x > INT32_MAX - rounding_offset) {
return 1 << (31 - exponent);
}
return (x + rounding_offset) >> exponent;
}

int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
} }


int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
int32_t right_shift) {
return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}

int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) { int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift); return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift);
} }


+ 5
- 0
mindspore/lite/nnacl/quantization/fixed_point.h View File

@@ -40,8 +40,13 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b);
// or arithmetic right shift with rounding // or arithmetic right shift with rounding
int RoundingDivideByPOT(int x, int exponent); int RoundingDivideByPOT(int x, int exponent);


int UpwardRounding(int x, int exponent);

int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift);


int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
int32_t right_shift);

int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift); int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift);


int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent); int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent);


+ 26
- 2
mindspore/lite/nnacl/quantization/quantize.c View File

@@ -15,6 +15,7 @@
*/ */


#include "nnacl/quantization/quantize.h" #include "nnacl/quantization/quantize.h"
#include <stdio.h>


const uint64_t dSignMask = 1ull << 63; const uint64_t dSignMask = 1ull << 63;
const uint64_t dExponentMask = 0x7ffull << 52; const uint64_t dExponentMask = 0x7ffull << 52;
@@ -35,8 +36,8 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantiz
*right_shift = -shift; *right_shift = -shift;
} }


void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
int shift = 0; int shift = 0;
QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift);
shift = -shift; shift = -shift;
@@ -49,6 +50,29 @@ void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multipl
} }
} }


void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
int shift = 0;
const uint32_t scale_bits = (uint32_t)(double_multiplier);
  /* multiplier is in [0x40000000, 0x7FFFFF80] range */
*quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n",
quantized_multiplier[0]);
return;
}
/* shift is in [0, 31] range */
shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23);
shift = -shift;
if (shift < 0) {
*left_shift = 0;
*right_shift = shift;
} else {
*left_shift = shift;
*right_shift = 0;
}
}

uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }


int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }


+ 7
- 2
mindspore/lite/nnacl/quantization/quantize.h View File

@@ -34,6 +34,8 @@ typedef struct QuantArg {
} QuantArg; } QuantArg;


typedef struct ConvQuantArg { typedef struct ConvQuantArg {
RoundingMode round_mode_;
CalFixedMultiplierMode quant_multiplier_mode_;
QuantArg *input_quant_args_; QuantArg *input_quant_args_;
QuantArg *filter_quant_args_; QuantArg *filter_quant_args_;
QuantArg *output_quant_args_; QuantArg *output_quant_args_;
@@ -46,7 +48,6 @@ typedef struct ConvQuantArg {
size_t input_arg_num_; size_t input_arg_num_;
size_t filter_arg_num_; size_t filter_arg_num_;
size_t output_arg_num_; size_t output_arg_num_;
uint8_t asymmetric_;
uint8_t per_channel_; uint8_t per_channel_;
} ConvQuantArg; } ConvQuantArg;


@@ -282,7 +283,11 @@ void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier,


void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift); void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift);


void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, int *right_shift);
void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift);

void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift);


uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp);




+ 2
- 0
mindspore/lite/schema/model.fbs View File

@@ -40,6 +40,8 @@ table QuantParam {
varCorr: float = 1; varCorr: float = 1;
meanCorr: float = 0; meanCorr: float = 0;
dstDtype: int = 32; dstDtype: int = 32;
roundType: int = 1;
multiplier: int = -1; // calculate fixed point multiplier method
} }


table Tensor { table Tensor {


+ 3
- 0
mindspore/lite/src/lite_session.cc View File

@@ -69,6 +69,9 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit
quant_arg.var_corr = quant_params->Get(j)->varCorr(); quant_arg.var_corr = quant_params->Get(j)->varCorr();
quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); quant_arg.mean_corr = quant_params->Get(j)->meanCorr();
quant_arg.inited = quant_params->Get(j)->inited(); quant_arg.inited = quant_params->Get(j)->inited();
quant_arg.roundType = quant_params->Get(j)->roundType();
quant_arg.multiplier = quant_params->Get(j)->multiplier();
quant_arg.dstDtype = quant_params->Get(j)->dstDtype();
dst_tensor->AddQuantParam(quant_arg); dst_tensor->AddQuantParam(quant_arg);
} }
} }


+ 34
- 4
mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc View File

@@ -261,12 +261,43 @@ int ConvolutionBaseCPUKernel::SetQuantMultiplier() {
static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_); static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_);
double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_); double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_);
conv_quant_arg_->real_multiplier_[i] = real_multiplier; conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
if (conv_quant_arg_->quant_multiplier_mode_ == Method_SinglePrecision) {
QuantizeRoundParameterWithSinglePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
} else if (conv_quant_arg_->quant_multiplier_mode_ == Method_DoublePrecision) {
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
}
} }
return RET_OK; return RET_OK;
} }


void ConvolutionBaseCPUKernel::SetRoundingAndMultipilerMode() {
auto input_quant_arg = in_tensors_.at(kInputIndex)->quant_params().front();
int round_type = input_quant_arg.roundType;
switch (round_type) {
case 1:
conv_quant_arg_->round_mode_ = Rounding_Away_from_zero;
break;
case 2:
conv_quant_arg_->round_mode_ = Rounding_Up;
break;
default:
conv_quant_arg_->round_mode_ = Rounding_No;
}
int cal_multiplier_type = input_quant_arg.multiplier;
switch (cal_multiplier_type) {
case 0:
conv_quant_arg_->quant_multiplier_mode_ = Method_SinglePrecision;
break;
case 1:
conv_quant_arg_->quant_multiplier_mode_ = Method_DoublePrecision;
break;
default:
conv_quant_arg_->quant_multiplier_mode_ = Method_No;
}
}

int ConvolutionBaseCPUKernel::SetQuantParam() { int ConvolutionBaseCPUKernel::SetQuantParam() {
auto ret = MallocQuantParam(); auto ret = MallocQuantParam();
if (ret != RET_OK) { if (ret != RET_OK) {
@@ -288,13 +319,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed."; MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed.";
return ret; return ret;
} }

ret = SetIfPerChannel(); ret = SetIfPerChannel();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Set if per tensor channel failed."; MS_LOG(ERROR) << "Set if per tensor channel failed.";
return ret; return ret;
} }
SetRoundingAndMultipilerMode();
ret = SetQuantMultiplier(); ret = SetQuantMultiplier();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Quant Multiplier Failed."; MS_LOG(ERROR) << "Set Quant Multiplier Failed.";


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h View File

@@ -53,6 +53,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int SetFilterTensorQuantParam(); int SetFilterTensorQuantParam();
int SetOutputTensorQuantParam(); int SetOutputTensorQuantParam();
int SetQuantMultiplier(); int SetQuantMultiplier();
void SetRoundingAndMultipilerMode();
int CheckResizeValid(); int CheckResizeValid();
void FreeQuantParam(); void FreeQuantParam();




+ 10
- 2
mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc View File

@@ -120,8 +120,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread); quant_arg.zeroPoint, num_unit_thread);
} else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) { } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
bool from_uint8_src = false;
if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
from_uint8_src = true;
}
ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
} else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
@@ -138,8 +142,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
input_quant_arg.scale, input_quant_arg.zeroPoint); input_quant_arg.scale, input_quant_arg.zeroPoint);
if (ret) { if (ret) {
auto output_quant_arg = out_tensors_.front()->quant_params().front(); auto output_quant_arg = out_tensors_.front()->quant_params().front();
bool from_uint8_src = false;
if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
from_uint8_src = true;
}
ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale, ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale,
output_quant_arg.zeroPoint, num_unit_thread);
output_quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
} }
} }




+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc View File

@@ -254,8 +254,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]); const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(output_scale_[i]); double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
conv_quant_arg_->real_multiplier_[i] = real_multiplier; conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
} }


// now only consider per tensor for output // now only consider per tensor for output


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc View File

@@ -132,8 +132,8 @@ int FullconnectionInt8CPUKernel::Init() {
for (int i = 0; i < weight_quant_num; ++i) { for (int i = 0; i < weight_quant_num; ++i) {
const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]); const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_); double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_);
QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
&quant_.right_shift_[i]);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
&quant_.right_shift_[i]);
} }


CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc View File

@@ -138,8 +138,8 @@ int MatmulInt8CPUKernel::ReSize() {
} }
} }
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
return RET_OK; return RET_OK;
} }




+ 2
- 1
mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc View File

@@ -39,7 +39,8 @@ int ReluXInt8CPUKernel::Init() {
quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint; quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint;


const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_; const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_);
QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_,
&quant_arg_.right_shift_);


return RET_OK; return RET_OK;
} }


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc View File

@@ -86,8 +86,8 @@ int ResizeInt8CPUKernel::Init() {
quant_out_->zp_ = output->quant_params().front().zeroPoint; quant_out_->zp_ = output->quant_params().front().zeroPoint;
quant_out_->scale_ = output->quant_params().front().scale; quant_out_->scale_ = output->quant_params().front().scale;


QuantizeRoundParameter(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, &multiplier_->left_shift_,
&multiplier_->right_shift_);
QuantizeRoundParameterWithDoublePrecision(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_,
&multiplier_->left_shift_, &multiplier_->right_shift_);
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;
} }


+ 3
- 0
mindspore/lite/src/tensor.h View File

@@ -38,6 +38,9 @@ struct QuantArg {
bool inited; bool inited;
std::vector<float> clusters{}; std::vector<float> clusters{};
int bitNum; int bitNum;
int roundType;
int multiplier;
int dstDtype;
}; };


class Tensor : public mindspore::tensor::MSTensor { class Tensor : public mindspore::tensor::MSTensor {


+ 1
- 1
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc View File

@@ -118,7 +118,7 @@ TEST_F(TestMatmulInt8, simple) {
int a_sums[ROW4] = {0}; int a_sums[ROW4] = {0};
int bias[COL4] = {0}; int bias[COL4] = {0};
int multiplier, ls, rs; int multiplier, ls, rs;
QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs);
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
&rs, ROW, COL, COL, false); &rs, ROW, COL, COL, false);


+ 5
- 0
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -121,6 +121,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
std::make_unique<schema::QuantParamT>(input_quant_param); std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint; << " zp: " << input_quant_param_ptr->zeroPoint;
input_quant_param_ptr->dstDtype = tensor_input->dataType;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
} }
} }
@@ -151,6 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
std::make_unique<schema::QuantParamT>(channel_quant_param); std::make_unique<schema::QuantParamT>(channel_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint; << " zp: " << output_quant_param_ptr->zeroPoint;
output_quant_param_ptr->dstDtype = output_tensor->dataType;
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr)); output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
} }
} }
@@ -258,6 +260,9 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
auto subgraph_name = func_graph->get_attr("graph_name"); auto subgraph_name = func_graph->get_attr("graph_name");
MS_ASSERT(subgraph_name != nullptr); MS_ASSERT(subgraph_name != nullptr);
sub_graphT->name = GetValue<std::string>(subgraph_name); sub_graphT->name = GetValue<std::string>(subgraph_name);
auto fmk = func_graph->get_attr("fmk");
MS_ASSERT(fmk != nullptr);
meta_graphT->fmkType = GetValue<int>(fmk);


auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) { for (const auto &cnode : cnodes) {


+ 9
- 1
mindspore/lite/tools/common/graph_util.cc View File

@@ -448,6 +448,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128; preTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} }
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -491,6 +493,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128; preTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} }
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -552,8 +556,10 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
MS_ASSERT(prim != nullptr); MS_ASSERT(prim != nullptr);
postTensor->dataType = prim->srcT; postTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128; toAddTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
postTensor->quantParams.front()->zeroPoint += 128;
} }
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -624,6 +630,8 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128; toAddTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
postTensor->quantParams.front()->zeroPoint += 128;
} }
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));


+ 2
- 0
mindspore/lite/tools/common/tensor_util.cc View File

@@ -38,6 +38,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
dstQuantParam->max = srcQuantParam->max; dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange; dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits; dstQuantParam->numBits = srcQuantParam->numBits;
dstQuantParam->dstDtype = srcQuantParam->dstDtype;
dstQuantParam->multiplier = srcQuantParam->multiplier;
return dstQuantParam; return dstQuantParam;
} }




+ 1
- 0
mindspore/lite/tools/converter/converter.cc View File

@@ -71,6 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return nullptr; return nullptr;
} }
graph->set_attr("graph_name", MakeValue("main_graph")); graph->set_attr("graph_name", MakeValue("main_graph"));
graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
} else { } else {
MS_ASSERT(nullptr != modelParser); MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile; const std::string modelFile = flag->modelFile;


+ 1
- 1
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -158,7 +158,7 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }


if (this->trainModel == true) {
if (this->trainModel) {
if (this->fmk != FmkType_MS) { if (this->fmk != FmkType_MS) {
std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format"; std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;


+ 8
- 1
mindspore/lite/tools/converter/converter_flags.h View File

@@ -30,7 +30,14 @@ using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE; using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant; using mindspore::schema::QuantType_WeightQuant;
namespace converter { namespace converter {
enum FmkType { FmkType_TF = 0, FmkType_CAFFE = 1, FmkType_ONNX = 2, FmkType_MS = 3, FmkType_TFLITE = 4 };
enum FmkType {
FmkType_TF = 0,
FmkType_CAFFE = 1,
FmkType_ONNX = 2,
FmkType_MS = 3,
FmkType_TFLITE = 4,
FmkType_ONNX_LOW_VERSION = 5
};


class Flags : public virtual mindspore::lite::FlagParser { class Flags : public virtual mindspore::lite::FlagParser {
public: public:


+ 1
- 1
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc View File

@@ -161,7 +161,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
if (postTensor->dataType != TypeId::kNumberTypeInt8) { if (postTensor->dataType != TypeId::kNumberTypeInt8) {
continue; continue;
} }
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
return RET_ERROR; return RET_ERROR;


+ 0
- 1
mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc View File

@@ -25,7 +25,6 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
quant_param->min = 0; quant_param->min = 0;
quant_param->max = 0; quant_param->max = 0;
quant_param->narrowRange = true; quant_param->narrowRange = true;
quant_param->dstDtype = TypeId::kNumberTypeFloat32;
} }
} }
return RET_OK; return RET_OK;


+ 4
- 3
mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc View File

@@ -44,7 +44,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
} }
} }
} }
int index = -1;
unsigned int index = -1;
for (auto &tensor : graph->allTensors) { for (auto &tensor : graph->allTensors) {
index++; index++;
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
@@ -59,7 +59,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
auto &quantParam = tensor->quantParams.front(); auto &quantParam = tensor->quantParams.front();
size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get())); size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get()));
void *oriWeightData = tensor->data.data(); void *oriWeightData = tensor->data.data();
if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 ||
quantParam->dstDtype == TypeId::kNumberTypeFloat) {
std::vector<int8_t> qDatas(wShapeSize); std::vector<int8_t> qDatas(wShapeSize);
auto weightQauntParam = GetTensorQuantParam(tensor); auto weightQauntParam = GetTensorQuantParam(tensor);
if (tensor->dataType == TypeId::kNumberTypeFloat || if (tensor->dataType == TypeId::kNumberTypeFloat ||
@@ -71,7 +72,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
for (size_t j = 0; j < wShapeSize; j++) { for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
} }
} else { // tflite awareing quant
} else { // convert uint8 to int8
auto *weightData = static_cast<uint8_t *>(oriWeightData); auto *weightData = static_cast<uint8_t *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) { for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128; qDatas[j] = (int32_t)weightData[j] - 128;


+ 1
- 0
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc View File

@@ -55,6 +55,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
return nullptr; return nullptr;
} }
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
return func_graph_ptr_; return func_graph_ptr_;
} }




+ 9
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -40,6 +40,7 @@ std::set<std::string> SPECIAL_NODE = {"Gemm"};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) { const QuantType &quant_type) {
NoSupportOp::GetInstance()->SetFmkType("ONNX"); NoSupportOp::GetInstance()->SetFmkType("ONNX");
anf_root_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file); auto status = InitOriginModel(model_file);
if (RET_OK != status) { if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -47,7 +48,6 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
return nullptr; return nullptr;
} }


anf_root_graph_ = std::make_shared<FuncGraph>();
status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node");
if (RET_OK != status) { if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -77,6 +77,11 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
} }
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
onnx_root_graph_ = onnx_model_.graph(); onnx_root_graph_ = onnx_model_.graph();
if (OnnxNodeParser::opset_version() > 15) {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
} else {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
}
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph, STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
@@ -614,6 +619,9 @@ STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_na
std::vector<QuantParamT> *quant_params) { std::vector<QuantParamT> *quant_params) {
quant_params->clear(); quant_params->clear();
auto quant_param = std::make_unique<QuantParamT>(); auto quant_param = std::make_unique<QuantParamT>();
if (OnnxNodeParser::opset_version() <= 15) {
quant_param->multiplier = 0;
}
std::string quant_tensor_name = "scale_" + tensor_name; std::string quant_tensor_name = "scale_" + tensor_name;
auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true); auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
if (status != RET_OK) { if (status != RET_OK) {


+ 2
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -366,6 +366,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
return nullptr; return nullptr;
} }
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));


for (int i = 0; i < tf_root_graph_->node_size(); i++) { for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i); auto &node_def = tf_root_graph_->node(i);
@@ -441,6 +442,7 @@ STATUS TFModelParser::ConvertSubgraph() {


FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
// convert sub graph inputs // convert sub graph inputs
std::vector<ParameterPtr> sub_graph_inputs; std::vector<ParameterPtr> sub_graph_inputs;


+ 11
- 3
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -55,6 +55,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
return nullptr; return nullptr;
} }
func_graph_ = std::make_shared<FuncGraph>(); func_graph_ = std::make_shared<FuncGraph>();
func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));


auto status = ConvertGraphInputs(); auto status = ConvertGraphInputs();
if (status != RET_OK) { if (status != RET_OK) {
@@ -183,7 +184,7 @@ STATUS TfliteModelParser::ConvertOps() {
} }


STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor, STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor,
std::vector<QuantParamT> *quant_params) {
std::vector<QuantParamT> *quant_params, int round_type) {
if (tflite_tensor == nullptr) { if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed."; MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed.";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -221,6 +222,8 @@ STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tens
quant_param->max = tflite_tensor->quantization->max[i]; quant_param->max = tflite_tensor->quantization->max[i];
} }
quant_param->inited = true; quant_param->inited = true;
quant_param->roundType = round_type;
quant_param->multiplier = 1;
quant_params->emplace_back(*std::move(quant_param)); quant_params->emplace_back(*std::move(quant_param));
} }
return RET_OK; return RET_OK;
@@ -236,6 +239,11 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
return RET_NULL_PTR; return RET_NULL_PTR;
} }

int round_type = 1;
if (primitive_c->primitiveT()->value.type == PrimitiveType_Conv2D) {
round_type = 2;
}
const auto &tflite_subgraph = tflite_model_->subgraphs.front(); const auto &tflite_subgraph = tflite_model_->subgraphs.front();
for (auto input_idx : op->inputs) { for (auto input_idx : op->inputs) {
if (input_idx < 0) { if (input_idx < 0) {
@@ -243,7 +251,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
} }
const auto &input_tensor = tflite_subgraph->tensors[input_idx]; const auto &input_tensor = tflite_subgraph->tensors[input_idx];
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
auto status = SetTensorQuantParam(input_tensor.get(), &quant_params);
auto status = SetTensorQuantParam(input_tensor.get(), &quant_params, round_type);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "set input tensor quant param failed."; MS_LOG(ERROR) << "set input tensor quant param failed.";
return status; return status;
@@ -256,7 +264,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
} }
const auto &output_tensor = tflite_subgraph->tensors.at(output_idx); const auto &output_tensor = tflite_subgraph->tensors.at(output_idx);
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
auto status = SetTensorQuantParam(output_tensor.get(), &quant_params);
auto status = SetTensorQuantParam(output_tensor.get(), &quant_params, round_type);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "set output tensor quant param failed."; MS_LOG(ERROR) << "set output tensor quant param failed.";
return status; return status;


+ 2
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h View File

@@ -48,7 +48,8 @@ class TfliteModelParser : public ModelParser {
STATUS ConvertOps(); STATUS ConvertOps();
STATUS ConvertGraphInputs(); STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs(); STATUS ConvertGraphOutputs();
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params);
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
int round_type = 1);
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // LITE_TFLITE_MODEL_PARSER_H #endif // LITE_TFLITE_MODEL_PARSER_H

+ 4
- 0
mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc View File

@@ -595,6 +595,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru
quant_param.numBits = bit_num; quant_param.numBits = bit_num;
quant_param.narrowRange = false; quant_param.narrowRange = false;
quant_param.inited = true; quant_param.inited = true;
quant_param.roundType = 1;
quant_param.multiplier = 1;
std::vector<schema::QuantParamT> quant_params = {quant_param}; std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->AddInputQuantParam(quant_params); lite_primitive->AddInputQuantParam(quant_params);
return RET_OK; return RET_OK;
@@ -612,6 +614,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
quant_param.numBits = bit_num; quant_param.numBits = bit_num;
quant_param.narrowRange = false; quant_param.narrowRange = false;
quant_param.inited = true; quant_param.inited = true;
quant_param.roundType = 1;
quant_param.multiplier = 1;
std::vector<schema::QuantParamT> quant_params = {quant_param}; std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->AddOutputQuantParam(quant_params); lite_primitive->AddOutputQuantParam(quant_params);
return RET_OK; return RET_OK;


Loading…
Cancel
Save