Merge pull request !5324 from zhanyuan/dev
@@ -225,12 +225,15 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) {
 }
 // dst: weight_zp * input_row_sums
-void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
   for (int r = 0; r < row; ++r) {
     int sum = 0;
     for (int c = 0; c < col; ++c) {
-      int src_idx = r * col + c;
-      sum += input[src_idx];
+      if (order == RowMajor) {
+        sum += input[r * col + c];
+      } else {
+        sum += input[c * row + r];
+      }
     }
     sum *= weight_zp;
     dst[r] = sum;
@@ -238,12 +241,16 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
 }
 // dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order) {
   for (int c = 0; c < col; ++c) {
     int sum = 0;
     for (int r = 0; r < row; ++r) {
-      int src_idx = r * col + c;
-      sum += weight[src_idx];
+      if (order == RowMajor) {
+        sum += weight[r * col + c];
+      } else {
+        sum += weight[c * row + r];
+      }
     }
     dst[c] = row * input_zp * weight_zp - input_zp * sum;
     if (bias) {
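
For reference, both precomputed terms drop out of expanding the zero-point-corrected product. With za = input_zp, zb = weight_zp, and D = depth (the reduction length), sum_d (A[r][d] - za)(B[d][c] - zb) = sum_d A[r][d]*B[d][c] - zb*row_sum(A, r) - za*col_sum(B, c) + D*za*zb, which is exactly what CalcInputSums (zb times the input row sums) and CalcWeightBiasSums (bias + D*za*zb - za times the weight column sums) contribute. A minimal standalone check of that identity, with made-up shapes and zero points (a sketch, not part of the patch):

#include <stdint.h>
#include <stdio.h>

int main(void) {
  enum { ROW = 2, COL = 3, DEEP = 4 };
  const int8_t A[ROW][DEEP] = {{1, -2, 3, 4}, {5, 6, -7, 8}};
  const int8_t B[DEEP][COL] = {{1, 2, 3}, {4, 5, 6}, {-7, 8, 9}, {10, 11, -12}};
  const int input_zp = 3, weight_zp = -5;
  for (int r = 0; r < ROW; ++r) {
    for (int c = 0; c < COL; ++c) {
      int direct = 0, raw = 0, row_sum = 0, col_sum = 0;
      for (int d = 0; d < DEEP; ++d) {
        direct += (A[r][d] - input_zp) * (B[d][c] - weight_zp);  /* ground truth */
        raw += A[r][d] * B[d][c];                                /* what the int8 kernel accumulates */
        row_sum += A[r][d];
        col_sum += B[d][c];
      }
      /* input_sums[r] = weight_zp * row_sum;
       * weight_bias_sums[c] = DEEP * input_zp * weight_zp - input_zp * col_sum (bias omitted here). */
      int corrected = raw - weight_zp * row_sum + DEEP * input_zp * weight_zp - input_zp * col_sum;
      if (direct != corrected) {
        printf("mismatch at (%d, %d)\n", r, c);
        return 1;
      }
    }
  }
  printf("zero-point folding matches the direct computation\n");
  return 0;
}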
@@ -37,8 +37,9 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
 void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
-void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
-void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order);
 void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
                 int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
                 int stride);
@@ -55,6 +55,11 @@ typedef enum LiteDataType {
   kDataTypeInt8,
 } LiteDataType;
+
+typedef enum DataOrder {
+  RowMajor,
+  ColMajor,
+} DataOrder;
 typedef struct OpParameter {
   char name_[100];
   int type_;
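
DataOrder only switches the index arithmetic: for an M x N logical matrix, RowMajor reads element (r, c) at r * N + c, while ColMajor reads the same element at c * M + r, i.e. the buffer holds the transposed layout. A toy illustration with made-up values (not from the patch):

#include <assert.h>
#include <stdint.h>

int main(void) {
  /* The same logical 2x3 matrix in both storage orders. */
  const int8_t row_major[6] = {1, 2, 3, 4, 5, 6};  /* (r, c) at r * 3 + c */
  const int8_t col_major[6] = {1, 4, 2, 5, 3, 6};  /* (r, c) at c * 2 + r */
  for (int r = 0; r < 2; ++r) {
    int s0 = 0, s1 = 0;
    for (int c = 0; c < 3; ++c) {
      s0 += row_major[r * 3 + c];  /* the RowMajor branch */
      s1 += col_major[c * 2 + r];  /* the ColMajor branch */
    }
    assert(s0 == s1);  /* both branches yield the same row sums */
  }
  return 0;
}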
@@ -90,7 +90,7 @@ int FullconnectionInt8CPUKernel::ReSize() {
                            quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
                            &quant_params_.out_act_max);
   CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                     bias_ptr_, weight_bias_sums_);
+                     bias_ptr_, weight_bias_sums_, ColMajor);
   return RET_OK;
 }
@@ -136,7 +136,7 @@ int FullconnectionInt8CPUKernel::Run() {
   }
   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
   RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
-  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
+  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
   ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_);
   return RET_OK;
 }
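
The FullConnection kernel always receives its activation as a plain row x deep row-major tensor, so RowMajor is correct for the input sums here; the ColMajor passed in ReSize() above presumably reflects the weight blob being stored as col x deep, i.e. transposed relative to the deep x col view that the column sums iterate over.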
@@ -140,18 +140,21 @@ int MatmulInt8CPUKernel::Run() {
     if (params_->a_transpose_) {
       RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
     if (params_->b_transpose_) {
       RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, ColMajor);
     } else {
       RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, RowMajor);
     }
     c_ptr_ = c_ptr + i * c_stride;
-    auto &q = quant_params_;
-    CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
-    CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
     ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
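
This hunk is the substance of the fix: the removed unconditional CalcInputSums/CalcWeightBiasSums calls always indexed the operands as row-major, which would mis-compute the sums whenever a_transpose_ or b_transpose_ was set. Hoisting the calls into the packing branches and passing the matching DataOrder keeps the row and column sums aligned with how each buffer is actually laid out in memory.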
@@ -226,7 +226,8 @@ STATUS AwareQuantizer::DoQuantize() {
     }
     STATUS status;
     if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
-        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
+        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
+        GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
       auto inputIndexes = node->inputIndex;
       if (inputIndexes.size() < 2) {
         MS_LOG(ERROR) << node->name.c_str()
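
With FullConnection added alongside Conv2D and DepthwiseConv2D, the aware quantizer now runs the same weight-quantization path for fully connected nodes as well; the inputIndexes.size() < 2 guard above holds because, as for the convolution ops, the node's second input is presumably the weight tensor to be quantized.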