diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
index 78f2594cc8..6cbc991d71 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
@@ -34,12 +34,11 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
-  if (restore_data == nullptr) {
-    MS_LOG(ERROR) << "weight_tensor MutableData is nullptr.";
-    return nullptr;
-  }
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  auto *restore_data = weight_tensor->data_c();
+  auto is_const_quant_weight =
+    (restore_data != nullptr) &&
+    (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant);
+  if (is_const_quant_weight) {
     auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
     if (dequant_weight == nullptr) {
       MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -58,7 +57,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
   }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    if (is_const_quant_weight) {
      weight_tensor->FreeData();
       weight_tensor->SetData(restore_data);
     }
@@ -69,14 +68,14 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    if (is_const_quant_weight) {
       weight_tensor->FreeData();
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (is_const_quant_weight) {
     weight_tensor->FreeData();
     weight_tensor->SetData(restore_data);
   }
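Reviewer note: the hunk above replaces four copies of the `data_type() == kNumberTypeInt8 || ... WeightQuant` condition with a single `is_const_quant_weight` flag, and folds the old nullptr early-return into that flag (a weight tensor with no backing data is now simply treated as non-const instead of aborting kernel creation). Below is a minimal standalone sketch of the dequantize-then-restore pattern this creator uses; `Tensor` and `DequantWeight` are simplified stand-ins, not the MindSpore Lite API.

    // Sketch only: simplified stand-in types, not the real lite::Tensor.
    #include <cstdint>

    struct Tensor {
      void *data = nullptr;
      bool is_int8 = false;
      void *data_c() const { return data; }
      void SetData(void *d) { data = d; }
      // In this sketch FreeData is only ever called on the dequantized copy,
      // so deleting as float* is safe here.
      void FreeData() { delete[] static_cast<float *>(data); data = nullptr; }
    };

    static float *DequantWeight(Tensor *) { return new float[16](); }  // stub

    bool CreateKernelWithRestore(Tensor *weight, bool weight_quant) {
      void *restore_data = weight->data_c();
      // Evaluate the predicate once, as the patch does.
      const bool is_const_quant_weight =
          restore_data != nullptr && (weight->is_int8 || weight_quant);
      if (is_const_quant_weight) {
        float *dequant = DequantWeight(weight);
        if (dequant == nullptr) return false;
        weight->SetData(dequant);  // kernel init reads the float copy
      }
      bool ok = true;  // ... create and Init() the kernel here ...
      if (is_const_quant_weight) {
        weight->FreeData();             // release the dequantized copy
        weight->SetData(restore_data);  // put the original int8 data back
      }
      return ok;
    }

The key invariant, preserved on every exit path of the real creator, is that the tensor leaves the function pointing at its original quantized data.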
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index c63e612ef5..100e1773a5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -45,26 +45,52 @@ int MatmulInt8CPUKernel::ReSize() {
   params_->row_ = o_shape[o_shape.size() - 2];
   params_->col_ = o_shape[o_shape.size() - 1];
   params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
-  params_->row_8_ = UP_ROUND(params_->row_, 8);
-  params_->col_8_ = UP_ROUND(params_->col_, 8);
-
-  r4_ = UP_ROUND(params_->row_, 4);
-  c4_ = UP_ROUND(params_->col_, 4);
-  d16_ = UP_ROUND(params_->deep_, 16);
-  a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
+  params_->row_4_ = UP_ROUND(params_->row_, 4);
+  params_->col_4_ = UP_ROUND(params_->col_, 4);
+  params_->deep_16_ = UP_ROUND(params_->deep_, 16);
+  a_r4x16_ptr_ =
+    reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)));
   if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
-  memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
-  b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
-  if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
-  memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
-  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
+  memset(a_r4x16_ptr_, 0, params_->row_4_ * params_->deep_16_ * sizeof(int8_t));
+  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_4_ * sizeof(int)));
   if (!input_sums_) return RET_MEMORY_FAILED;
-  memset(input_sums_, 0, r4_ * sizeof(int));
-  weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
-  if (!weight_bias_sums_) return RET_MEMORY_FAILED;
-  memset(weight_bias_sums_, 0, c4_ * sizeof(int));
-  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
-  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+  memset(input_sums_, 0, params_->row_4_ * sizeof(int));
+  if (in_tensors_.size() == 3) {
+    auto bias_size = params_->col_4_ * sizeof(int);
+    bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_size));
+    if (!bias_ptr_) return RET_MEMORY_FAILED;
+    memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_size);
+  } else {
+    bias_ptr_ = NULL;
+  }
+
+  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+  if (params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * params_->deep_ * params_->col_;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_4_, 4));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_);

   auto input_tensor = in_tensors_[0];
   auto params = input_tensor->GetQuantParams();
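Reviewer note on the precomputation moved into ReSize: for a zero-point-quantized product, sum_k (A[i][k] - a_zp) * (B[k][j] - b_zp) expands to sum_k A*B - b_zp * sum_k A[i][k] - a_zp * sum_k B[k][j] + deep * a_zp * b_zp. The second term is the per-row quantity `CalcInputSums` tracks; the last two terms plus the bias depend only on the column, which is why the hunk can fold them into `weight_bias_sums_batch_` once at resize time whenever B is constant, instead of on every Run. A sketch of that per-column fold follows; the signature is simplified and is not the actual nnacl `CalcWeightBiasSums`.

    // Sketch of the per-column correction term for int8 GEMM with zero
    // points a_zp (input) and b_zp (weight). Assumes row-major B.
    #include <cstdint>

    void WeightBiasSumsSketch(const int8_t *b, int deep, int col, int a_zp,
                              int b_zp, const int *bias, int *dst) {
      for (int j = 0; j < col; ++j) {
        int colsum = 0;
        for (int k = 0; k < deep; ++k) {
          colsum += b[k * col + j];  // sum of column j of B
        }
        // deep*a_zp*b_zp - a_zp*colsum, plus the bias if one exists.
        dst[j] = deep * a_zp * b_zp - a_zp * colsum;
        if (bias != nullptr) {
          dst[j] += bias[j];
        }
      }
    }

One detail worth double-checking in the hunk itself: the bias memcpy copies `bias_size = params_->col_4_ * sizeof(int)` bytes out of `in_tensors_[2]`, which over-reads the bias tensor whenever col_ is not a multiple of 4.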
@@ -89,23 +115,24 @@ int MatmulInt8CPUKernel::ReSize() {
 }

 int MatmulInt8CPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, 4) - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
   int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
-  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
+  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * params_->deep_16_;
   auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
   auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
   auto &p = quant_params_;
 #ifdef ENABLE_ARM64
-  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
-                   p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
-                   params_->col_ * sizeof(int8_t), false);
+  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, params_->row_4_, cur_oc * C4NUM, params_->deep_16_, input_sums_,
+                   cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift,
+                   params_->row_, cur_oc_res, params_->col_ * sizeof(int8_t), false);
 #else
-  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, d16_, params_->col_, input_sums_, cur_bias,
-                    &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN, INT8_MAX, false);
+  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, params_->deep_16_, params_->col_,
+                    input_sums_, cur_bias, &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN,
+                    INT8_MAX, false);
 #endif

   return RET_OK;
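Reviewer note: the only change in RunImpl is the rename from the kernel-local r4_/c4_/d16_ fields to the shared parameter fields; the tiling itself is untouched. For reference, the partitioning arithmetic it relies on (a rounding-up division macro along the lines of UP_DIV in the tree) behaves as in this worked example:

    // Worked example of the RunImpl task split: the padded column count
    // col_4 is cut into 4-wide blocks, each task owns thread_stride
    // consecutive blocks, and MSMIN clamps the last task's share.
    #include <algorithm>
    #include <cstdio>

    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))

    int main() {
      int col_4 = 24;         // e.g. col_ = 21, rounded up to a multiple of 4
      int thread_count = 4;
      int blocks = UP_DIV(col_4, 4);                     // 6 blocks of 4 cols
      int thread_stride = UP_DIV(blocks, thread_count);  // 2 blocks per task
      for (int task_id = 0; task_id < thread_count; ++task_id) {
        int cur_oc = std::min(thread_stride, blocks - task_id * thread_stride);
        if (cur_oc <= 0) continue;  // tasks past the data do nothing
        std::printf("task %d: %d block(s) starting at column %d\n",
                    task_id, cur_oc, task_id * thread_stride * 4);
      }
      return 0;
    }

With col_4 = 24 and four threads, tasks 0-2 each take two 4-column blocks and task 3 exits early through the same `cur_oc <= 0` guard that RunImpl uses.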
@@ -127,33 +154,47 @@ int MatmulInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-  auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData());
-  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   auto a_stride = params_->row_ * params_->deep_;
   auto b_stride = params_->deep_ * params_->col_;
   auto c_stride = params_->row_ * params_->col_;

+  if (!params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * b_stride;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+
   for (int i = 0; i < params_->batch; ++i) {
     auto cur_a_ptr = a_ptr + i * a_stride;
-    auto cur_b_ptr = b_ptr + i * b_stride;
     if (params_->a_transpose_) {
-      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, params_->deep_16_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
-    if (params_->b_transpose_) {
-      RowMajor2Row16x4MajorInt8(cur_b_ptr, b_c16x4_ptr_, params_->col_, params_->deep_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, ColMajor);
-    } else {
-      RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, RowMajor);
-    }
+    b_c16x4_ptr_ = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+    weight_bias_sums_ = weight_bias_sums_batch_ + i * params_->col_4_;
     c_ptr_ = c_ptr + i * c_stride;
     ret = ParallelLaunch(this->context_->thread_pool_, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
index d40c824068..95c1a1416b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
@@ -43,28 +43,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
       ctx_->allocator->Free(a_r4x16_ptr_);
       a_r4x16_ptr_ = nullptr;
     }
-    if (b_c16x4_ptr_ != nullptr) {
-      ctx_->allocator->Free(b_c16x4_ptr_);
-      b_c16x4_ptr_ = nullptr;
-    }
     if (input_sums_ != nullptr) {
       ctx_->allocator->Free(input_sums_);
       input_sums_ = nullptr;
     }
-    if (weight_bias_sums_ != nullptr) {
-      ctx_->allocator->Free(weight_bias_sums_);
-      weight_bias_sums_ = nullptr;
+    if (b_c16x4_batch_ != nullptr) {
+      ctx_->allocator->Free(b_c16x4_batch_);
+      b_c16x4_batch_ = nullptr;
+    }
+    if (weight_bias_sums_batch_ != nullptr) {
+      ctx_->allocator->Free(weight_bias_sums_batch_);
+      weight_bias_sums_batch_ = nullptr;
+    }
+    if (bias_ptr_ != nullptr) {
+      ctx_->allocator->Free(bias_ptr_);
+      bias_ptr_ = nullptr;
     }
   }

   MatmulQuantArg quant_params_;
   int8_t *a_r4x16_ptr_ = nullptr;
   int8_t *b_c16x4_ptr_ = nullptr;
   int8_t *c_ptr_ = nullptr;
+  int *bias_ptr_ = nullptr;
   int *input_sums_ = nullptr;
   int *weight_bias_sums_ = nullptr;
-  int r4_;
-  int c4_;
-  int d16_;
+  int8_t *b_c16x4_batch_ = nullptr;
+  int *weight_bias_sums_batch_ = nullptr;
 };  // namespace mindspore::kernel
 }  // namespace mindspore::kernel
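Reviewer note on the header changes: after this patch, `b_c16x4_ptr_` and `weight_bias_sums_` are non-owning cursors into the batch-sized allocations, which is why the destructor now frees only the new `*_batch_` buffers (plus the `bias_ptr_` copy) and no longer frees the per-batch pointers. The layout, as an illustrative standalone struct with hypothetical names:

    // Sketch of the batched packing layout: one contiguous allocation per
    // buffer, with per-batch views formed by pointer arithmetic instead of
    // repacking B inside the batch loop.
    #include <cstdint>
    #include <vector>

    struct PackedB {
      std::vector<int8_t> data;  // batch * col_4 * deep_16 bytes, packed B
      std::vector<int> sums;     // batch * col_4 correction terms
      int col_4 = 0, deep_16 = 0;

      // View of batch i's packed weights; never freed on its own.
      int8_t *slice(int i) { return data.data() + i * col_4 * deep_16; }
      int *sums_slice(int i) { return sums.data() + i * col_4; }
    };

One caveat worth a second look: when `b_const_` is false, Run() allocates `b_c16x4_batch_` and `weight_bias_sums_batch_` on every invocation without releasing the previous buffers, so repeated inference on a non-const B would keep acquiring allocator memory; freeing them at the end of Run(), or allocating once in ReSize for both cases, would avoid that.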