Browse Source

!5324 Fix the bug of matmul_int8's pre-process

Merge pull request !5324 from zhanyuan/dev
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c16566fafc
6 changed files with 31 additions and 14 deletions
  1. +13
    -6
      mindspore/lite/nnacl/int8/matmul_int8.c
  2. +3
    -2
      mindspore/lite/nnacl/int8/matmul_int8.h
  3. +5
    -0
      mindspore/lite/nnacl/op_base.h
  4. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
  5. +6
    -3
      mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
  6. +2
    -1
      mindspore/lite/tools/converter/quantizer/aware_quantizer.cc

+ 13
- 6
mindspore/lite/nnacl/int8/matmul_int8.c View File

@@ -225,12 +225,15 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16)
 }
 
 // dst: weight_zp * input_row_sums
-void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
   for (int r = 0; r < row; ++r) {
     int sum = 0;
     for (int c = 0; c < col; ++c) {
-      int src_idx = r * col + c;
-      sum += input[src_idx];
+      if (order == RowMajor) {
+        sum += input[r * col + c];
+      } else {
+        sum += input[c * row + r];
+      }
     }
     sum *= weight_zp;
     dst[r] = sum;
@@ -238,12 +241,16 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst)
 }
 
 // dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order) {
   for (int c = 0; c < col; ++c) {
     int sum = 0;
     for (int r = 0; r < row; ++r) {
-      int src_idx = r * col + c;
-      sum += weight[src_idx];
+      if (order == RowMajor) {
+        sum += weight[r * col + c];
+      } else {
+        sum += weight[c * row + r];
+      }
     }
     dst[c] = row * input_zp * weight_zp - input_zp * sum;
     if (bias) {


+ 3
- 2
mindspore/lite/nnacl/int8/matmul_int8.h View File

@@ -37,8 +37,9 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
 
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
 void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
-void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
-void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order);
 void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
                 int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
                 int stride);


+ 5
- 0
mindspore/lite/nnacl/op_base.h View File

@@ -55,6 +55,11 @@ typedef enum LiteDataType {
   kDataTypeInt8,
 } LiteDataType;
 
+typedef enum DataOrder {
+  RowMajor,
+  ColMajor,
+} DataOrder;
+
 typedef struct OpParameter {
   char name_[100];
   int type_;


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc View File

@@ -90,7 +90,7 @@ int FullconnectionInt8CPUKernel::ReSize() {
                              quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
                              &quant_params_.out_act_max);
   CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                     bias_ptr_, weight_bias_sums_);
+                     bias_ptr_, weight_bias_sums_, ColMajor);
   return RET_OK;
 }


@@ -136,7 +136,7 @@ int FullconnectionInt8CPUKernel::Run() {
   }
   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
   RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
-  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
+  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
   ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_);
   return RET_OK;
 }


+ 6
- 3
mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc View File

@@ -140,18 +140,21 @@ int MatmulInt8CPUKernel::Run() {
 
     if (params_->a_transpose_) {
       RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
      RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
     if (params_->b_transpose_) {
       RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, ColMajor);
     } else {
       RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, RowMajor);
     }
     c_ptr_ = c_ptr + i * c_stride;
-    auto &q = quant_params_;
-    CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
-    CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
     ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";


+ 2
- 1
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc View File

@@ -226,7 +226,8 @@ STATUS AwareQuantizer::DoQuantize() {
     }
     STATUS status;
     if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
-        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
+        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
+        GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
       auto inputIndexes = node->inputIndex;
       if (inputIndexes.size() < 2) {
         MS_LOG(ERROR) << node->name.c_str()


Loading…
Cancel
Save