From 203a4d2ea80d443b6d1150510a88b3a05827ceb4 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Thu, 27 Aug 2020 11:43:04 +0800
Subject: [PATCH] Fix a bug in matmul_int8's pre-processing

---
 mindspore/lite/nnacl/int8/matmul_int8.c    | 19 +++++++++++++------
 mindspore/lite/nnacl/int8/matmul_int8.h    |  5 +++--
 mindspore/lite/nnacl/op_base.h             |  5 +++++
 .../kernel/arm/int8/fullconnection_int8.cc |  4 ++--
 .../runtime/kernel/arm/int8/matmul_int8.cc |  9 ++++++---
 .../converter/quantizer/aware_quantizer.cc |  3 ++-
 6 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c
index 31ac84a24b..13b8b8fdb9 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.c
+++ b/mindspore/lite/nnacl/int8/matmul_int8.c
@@ -225,12 +225,15 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1
 }
 
 // dst: weight_zp * input_row_sums
-void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
   for (int r = 0; r < row; ++r) {
     int sum = 0;
     for (int c = 0; c < col; ++c) {
-      int src_idx = r * col + c;
-      sum += input[src_idx];
+      if (order == RowMajor) {
+        sum += input[r * col + c];
+      } else {
+        sum += input[c * row + r];
+      }
     }
     sum *= weight_zp;
     dst[r] = sum;
@@ -238,12 +241,16 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
 }
 
 // dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order) {
   for (int c = 0; c < col; ++c) {
     int sum = 0;
     for (int r = 0; r < row; ++r) {
-      int src_idx = r * col + c;
-      sum += weight[src_idx];
+      if (order == RowMajor) {
+        sum += weight[r * col + c];
+      } else {
+        sum += weight[c * row + r];
+      }
     }
     dst[c] = row * input_zp * weight_zp - input_zp * sum;
     if (bias) {
diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index 04ff5972f3..c9e1d01873 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -37,8 +37,9 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
 void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
-void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
-void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order);
 
 void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
                 int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
                 int stride);
diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h
index 6154dedb74..e5bf293ed3 100644
--- a/mindspore/lite/nnacl/op_base.h
+++ b/mindspore/lite/nnacl/op_base.h
@@ -55,6 +55,11 @@ typedef enum LiteDataType {
   kDataTypeInt8,
 } LiteDataType;
 
+typedef enum DataOrder {
+  RowMajor,
+  ColMajor,
+} DataOrder;
+
 typedef struct OpParameter {
   char name_[100];
   int type_;
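Note on the nnacl changes above: both helpers used to index their operand as row-major unconditionally, so a transposed (effectively column-major) buffer got summed along the wrong axis. The standalone C sketch below mirrors the patched CalcInputSums and checks that both orders yield the same per-row sums; the main() scaffolding and sample values are illustrative assumptions, not MindSpore code.

    /* Minimal sketch of the patched CalcInputSums; test data is made up. */
    #include <stdint.h>
    #include <stdio.h>

    typedef enum DataOrder { RowMajor, ColMajor } DataOrder;

    void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
      for (int r = 0; r < row; ++r) {
        int sum = 0;
        for (int c = 0; c < col; ++c) {
          /* RowMajor: element (r, c) lives at r * col + c; ColMajor: at c * row + r. */
          sum += (order == RowMajor) ? input[r * col + c] : input[c * row + r];
        }
        dst[r] = sum * weight_zp;
      }
    }

    int main(void) {
      /* A is 2x3 = {{1,2,3},{4,5,6}}; its transpose stored row-major has the
       * same bytes as A stored column-major. */
      int8_t a_rowmajor[] = {1, 2, 3, 4, 5, 6};
      int8_t a_colmajor[] = {1, 4, 2, 5, 3, 6};
      int sums[2];
      CalcInputSums(a_rowmajor, 2, 3, 1, sums, RowMajor);
      printf("RowMajor: %d %d\n", sums[0], sums[1]); /* prints 6 15 */
      CalcInputSums(a_colmajor, 2, 3, 1, sums, ColMajor);
      printf("ColMajor: %d %d\n", sums[0], sums[1]); /* prints 6 15 */
      return 0;
    }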
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
index 48e4ffec66..e0ebe39441 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
@@ -90,7 +90,7 @@ int FullconnectionInt8CPUKernel::ReSize() {
                                     quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
                                     &quant_params_.out_act_max);
   CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                     bias_ptr_, weight_bias_sums_);
+                     bias_ptr_, weight_bias_sums_, ColMajor);
   return RET_OK;
 }
 
@@ -136,7 +136,7 @@ int FullconnectionInt8CPUKernel::Run() {
   }
   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
   RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
-  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
+  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
   ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index aa93c9d4c4..9c3bafe389 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -140,18 +140,21 @@ int MatmulInt8CPUKernel::Run() {
 
     if (params_->a_transpose_) {
       RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
     if (params_->b_transpose_) {
       RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, ColMajor);
     } else {
       RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, RowMajor);
     }
     c_ptr_ = c_ptr + i * c_stride;
-    auto &q = quant_params_;
-    CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
-    CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
     ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
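Background on why the kernels precompute these sums: with input zero point az and weight zero point bz, each output element is sum_d (a[r][d] - az) * (b[d][c] - bz), which expands to sum_d a[r][d]*b[d][c] - bz*(row sum of a) - az*(column sum of b) + deep*az*bz. CalcInputSums folds the bz term into input_sums_ and CalcWeightBiasSums folds the remaining correction (plus the bias) into weight_bias_sums_, matching the comments in matmul_int8.c, so the hot loop stays a plain int8 dot product. A small C sketch checking that identity; the toy shapes, values, and zero points are assumptions for illustration only.

    /* Verifies: sum((a - az) * (b - bz)) == raw_dot - input_sum + weight_bias_sum. */
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
      enum { ROW = 2, COL = 2, DEEP = 3 };
      int8_t a[ROW][DEEP] = {{1, 2, 3}, {4, 5, 6}};
      int8_t b[DEEP][COL] = {{1, 2}, {3, 4}, {5, 6}};
      int az = 1, bz = 2; /* made-up zero points */
      for (int r = 0; r < ROW; ++r) {
        for (int c = 0; c < COL; ++c) {
          int ref = 0, raw = 0, a_sum = 0, b_sum = 0;
          for (int d = 0; d < DEEP; ++d) {
            ref += (a[r][d] - az) * (b[d][c] - bz); /* reference result */
            raw += a[r][d] * b[d][c];               /* what the asm kernel computes */
            a_sum += a[r][d];
            b_sum += b[d][c];
          }
          int input_sum = bz * a_sum;                        /* CalcInputSums term */
          int weight_bias_sum = DEEP * az * bz - az * b_sum; /* CalcWeightBiasSums term, bias omitted */
          printf("ref=%d kernel=%d\n", ref, raw - input_sum + weight_bias_sum);
        }
      }
      return 0;
    }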
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
index 3c4ff4dc2c..2a55aedf93 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
@@ -226,7 +226,8 @@ STATUS AwareQuantizer::DoQuantize() {
     }
     STATUS status;
     if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
-        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
+        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
+        GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
       auto inputIndexes = node->inputIndex;
       if (inputIndexes.size() < 2) {
         MS_LOG(ERROR) << node->name.c_str()
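The converter change above makes AwareQuantizer give FullConnection the same input-index check and weight/bias treatment as Conv2D and DepthwiseConv2D. For context, the conventional rule such quantizers apply to a conv or fully-connected bias is to store it as int32 at scale input_scale * weight_scale with zero point 0, so it adds directly into the int32 accumulator; the sketch below illustrates that rule with made-up scales and bias values and is not code from the converter.

    /* Conventional int32 bias quantization; scales and bias values are invented. */
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
      float input_scale = 0.05f, weight_scale = 0.002f;
      float bias_fp32[] = {0.013f, -0.250f, 1.100f};
      /* Bias shares the accumulator's scale: input_scale * weight_scale, zp = 0. */
      double bias_scale = (double)input_scale * weight_scale;
      for (int i = 0; i < 3; ++i) {
        int32_t q = (int32_t)round(bias_fp32[i] / bias_scale);
        printf("bias[%d] = %f -> %d\n", i, bias_fp32[i], q);
      }
      return 0;
    }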