From 917adc6baf8d141d8ff77b7a5d8318bdc6c8568f Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Thu, 24 Sep 2020 15:31:02 +0800 Subject: [PATCH] Fix the bug of int8 matmul weight tensor init --- mindspore/lite/nnacl/int8/matmul_int8.c | 15 +++-- mindspore/lite/nnacl/int8/matmul_int8.h | 2 +- .../kernel/arm/int8/fullconnection_int8.cc | 23 ++++--- .../runtime/kernel/arm/int8/matmul_int8.cc | 62 +++++++++---------- .../kernel/arm/int8/matmul_int8_tests.cc | 2 +- 5 files changed, 56 insertions(+), 48 deletions(-) diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 7f09525e15..e7f60a39d9 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -268,13 +268,18 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, return; } -void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) { +void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst) { + int row_16 = UP_ROUND(row, C16NUM); int stride = sizeof(int8_t) * 16 * 4; - for (int r = 0; r < row; ++r) { + for (int r = 0; r < row_16; ++r) { for (int c = 0; c < col; ++c) { - int stride_n = c / 4 * (row_16 / 16) + r / 16; - int src_idx = r * col + c; - dst[stride * stride_n + c % 4 * 16 + r % 16] = src[src_idx]; + int stride_idx = c / 4 * (row_16 / 16) + r / 16; + if (r >= row) { + dst[stride * stride_idx + c % 4 * 16 + r % 16] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % 4 * 16 + r % 16] = src[src_idx]; + } } } } diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 47e51a86ee..7aa1285dbe 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -33,7 +33,7 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, bool per_channel); void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); -void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16); +void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst); void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst, DataOrder order); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc index bdaf23c972..82e96104ae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -58,13 +58,11 @@ int FullconnectionInt8CPUKernel::ReSize() { weight_bias_sums_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * sizeof(int))); if (!weight_bias_sums_) return RET_MEMORY_FAILED; memset(weight_bias_sums_, 0, c4_ * sizeof(int)); - auto weight_data = reinterpret_cast(in_tensors_[1]->MutableData()); - RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_); if (in_tensors_.size() == 3) { auto bias_len = fc_param_->col_8_ * sizeof(int); bias_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(bias_len)); if (!bias_ptr_) return RET_MEMORY_FAILED; - memcpy(bias_ptr_, in_tensors_[2]->MutableData(), bias_len); + memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_len); } else { bias_ptr_ = NULL; } @@ -91,8 +89,13 @@ int FullconnectionInt8CPUKernel::ReSize() { CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min, &quant_params_.out_act_max); - CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, - bias_ptr_, weight_bias_sums_, ColMajor); + fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr); + if (fc_param_->b_const_) { + auto weight_data = reinterpret_cast(in_tensors_[1]->data_c()); + RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_); + CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, + quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor); + } return RET_OK; } @@ -106,7 +109,7 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) { auto &p = fc_param_; auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_; auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM; - auto output_ptr = reinterpret_cast(out_tensors_[0]->MutableData()); + auto output_ptr = reinterpret_cast(out_tensors_[0]->data_c()); auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM; #ifdef ENABLE_ARM64 MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min, @@ -136,9 +139,15 @@ int FullconnectionInt8CPUKernel::Run() { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - auto input_ptr = reinterpret_cast(in_tensors_[0]->MutableData()); + auto input_ptr = reinterpret_cast(in_tensors_[0]->data_c()); RowMajor2Row16x4MajorInt8(input_ptr, a_r4x16_ptr_, fc_param_->row_, fc_param_->deep_); CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor); + if (!fc_param_->b_const_) { + auto weight_data = reinterpret_cast(in_tensors_[1]->data_c()); + RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_); + CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, + quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor); + } ParallelLaunch(this->context_->thread_pool_, FcInt8Run, this, thread_count_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc index 100e1773a5..3952c0798b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc @@ -55,6 +55,14 @@ int MatmulInt8CPUKernel::ReSize() { input_sums_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_4_ * sizeof(int))); if (!input_sums_) return RET_MEMORY_FAILED; memset(input_sums_, 0, params_->row_4_ * sizeof(int)); + b_c16x4_batch_ = reinterpret_cast( + ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t))); + if (!b_c16x4_batch_) return RET_MEMORY_FAILED; + memset(b_c16x4_batch_, 0, params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)); + weight_bias_sums_batch_ = + reinterpret_cast(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int))); + if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED; + memset(weight_bias_sums_batch_, 0, params_->batch * params_->col_4_ * sizeof(int)); if (in_tensors_.size() == 3) { auto bias_size = params_->col_4_ * sizeof(int); bias_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(bias_size)); @@ -63,32 +71,6 @@ int MatmulInt8CPUKernel::ReSize() { } else { bias_ptr_ = NULL; } - - params_->b_const_ = (in_tensors_[1]->data_c() != nullptr); - if (params_->b_const_) { - b_c16x4_batch_ = reinterpret_cast( - ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t))); - if (!b_c16x4_batch_) return RET_MEMORY_FAILED; - weight_bias_sums_batch_ = - reinterpret_cast(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int))); - if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED; - auto b_ptr = reinterpret_cast(in_tensors_[1]->data_c()); - - for (int i = 0; i < params_->batch; ++i) { - auto cur_b = b_ptr + i * params_->deep_ * params_->col_; - auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_; - auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_; - if (params_->b_transpose_) { - RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_); - CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, - bias_ptr_, cur_sums, ColMajor); - } else { - RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_); - CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, - bias_ptr_, cur_sums, RowMajor); - } - } - } thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_4_, 4)); thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_); @@ -108,6 +90,24 @@ int MatmulInt8CPUKernel::ReSize() { quant_params_.output.zp_ = params.front().zeroPoint; quant_params_.output.scale_ = params.front().scale; + params_->b_const_ = (in_tensors_[1]->data_c() != nullptr); + if (params_->b_const_) { + auto b_ptr = reinterpret_cast(in_tensors_[1]->data_c()); + for (int i = 0; i < params_->batch; ++i) { + auto cur_b = b_ptr + i * params_->deep_ * params_->col_; + auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_; + auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_; + if (params_->b_transpose_) { + RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_); + CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, + bias_ptr_, cur_sums, ColMajor); + } else { + RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack); + CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, + bias_ptr_, cur_sums, RowMajor); + } + } + } double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift, &quant_params_.right_shift); @@ -161,12 +161,6 @@ int MatmulInt8CPUKernel::Run() { auto c_stride = params_->row_ * params_->col_; if (!params_->b_const_) { - b_c16x4_batch_ = reinterpret_cast( - ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t))); - if (!b_c16x4_batch_) return RET_MEMORY_FAILED; - weight_bias_sums_batch_ = - reinterpret_cast(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int))); - if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED; auto b_ptr = reinterpret_cast(in_tensors_[1]->data_c()); for (int i = 0; i < params_->batch; ++i) { auto cur_b = b_ptr + i * b_stride; @@ -177,7 +171,7 @@ int MatmulInt8CPUKernel::Run() { CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, bias_ptr_, cur_sums, ColMajor); } else { - RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_); + RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack); CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, bias_ptr_, cur_sums, RowMajor); } @@ -187,7 +181,7 @@ int MatmulInt8CPUKernel::Run() { for (int i = 0; i < params_->batch; ++i) { auto cur_a_ptr = a_ptr + i * a_stride; if (params_->a_transpose_) { - RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, params_->deep_16_); + RowMajor2Col16x4MajorInt8(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_); CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor); } else { RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index 52220834c6..6ad8b03222 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -114,7 +114,7 @@ TEST_F(TestMatmulInt8, simple) { int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16]; memset(b_c16x4, 0, COL4 * DEPTH16); RowMajor2Row16x4MajorInt8(a, a_r4x16, ROW, DEPTH); - RowMajor2Col16x4Major(b, DEPTH, COL, b_c16x4, DEPTH16); + RowMajor2Col16x4MajorInt8(b, DEPTH, COL, b_c16x4); int a_sums[ROW4] = {0}; int bias[COL4] = {0}; int multiplier, ls, rs;