diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
index 48d8b21ea7..5dfe110627 100644
--- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
+++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
@@ -58,11 +58,11 @@ int MatMulBaseInt8Coder::InitTmpBuffer() {
 MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { FreeQuantParam(); }
 
 void MatMulBaseInt8Coder::ResizeParameter() {
-  param_->row_align_ = UP_ROUND(param_->row_, C4NUM);
-  param_->col_align_ = UP_ROUND(param_->col_, C4NUM);
+  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
+  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
   param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
-  thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, C4NUM));
-  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_);
+  thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, col_tile_));
+  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
 }
 
 void MatMulBaseInt8Coder::FreeQuantParam() {
@@ -138,6 +138,12 @@ int MatMulBaseInt8Coder::InitQuantParam() {
 void MatMulBaseInt8Coder::InitParameter() {
   param_->a_const_ = (input_tensor_ != nullptr);
   param_->b_const_ = (filter_tensor_ != nullptr);
+  row_tile_ = C4NUM;
+  if (target_ == kARM32A) {
+    col_tile_ = C2NUM;
+  } else {
+    col_tile_ = C4NUM;
+  }
 }
 
 int MatMulBaseInt8Coder::InitBias() {
@@ -198,6 +204,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
                       param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_,
                       "init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_);
   } else {
+    code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_, false);
     code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch,
                       param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_,
                       "init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_);
@@ -225,7 +232,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
     std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_);
     std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_);
 
-    int stride = thread_stride_ * C4NUM;
+    int stride = thread_stride_ * col_tile_;
     int cur_stride = task_id * stride;
     int res_stride = param_->col_ - cur_stride;
     int cur_oc = MSMIN(stride, res_stride);
diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
index 5f0dbd09de..c893324df3 100644
--- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
+++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
@@ -65,6 +65,8 @@ class MatMulBaseInt8Coder : public OperatorCoder {
 
  private:
   int weight_quant_num_{0};
+  int row_tile_{C4NUM};
+  int col_tile_{C4NUM};
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_
diff --git a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c
index c51d3cafb3..c3c16c3fb4 100644
--- a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c
+++ b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c
@@ -38,10 +38,18 @@ void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_
     int8_t *cur_b_pack = dst_ptr + i * col_align * deep_16;
     int32_t *cur_sums = weight_bias_sums_batch_ + i * col_align;
     if (b_transpose) {
+#ifdef ENABLE_ARM32
+      RowMajor2Row2x16MajorInt8(cur_b, cur_b_pack, col, deep);
+#else
       RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep);
+#endif
       CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, filter_per_channel);
     } else {
+#ifdef ENABLE_ARM32
+      RowMajor2Col16x2MajorInt8(cur_b, cur_b_pack, deep, col);
+#else
       RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack);
+#endif
       CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false);
     }
   }
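For reference, here is a minimal standalone sketch (not part of the patch) of the tiling arithmetic that ResizeParameter() performs once the tile width is target-dependent: on ARM32 matrix B is packed with a 2-wide column tile (the new RowMajor2Row2x16/RowMajor2Col16x2 paths), so col_align_, thread_count_, and the per-task column stride must all be derived from col_tile_ instead of the hard-coded C4NUM. The macros below are redefined locally with the same semantics as the nnacl UP_ROUND/UP_DIV/MSMIN macros; the col and thread_num values are made-up examples.

#include <stdio.h>

/* Same semantics as the nnacl helper macros used in ResizeParameter(). */
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) ((((x) + (y) - 1) / (y)) * (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int col = 10;       /* hypothetical output-column count */
  int thread_num = 4; /* hypothetical configured thread count */

  /* col_tile is C2NUM (2) on kARM32A and C4NUM (4) on other targets. */
  for (int col_tile = 2; col_tile <= 4; col_tile += 2) {
    int col_align = UP_ROUND(col, col_tile);
    int thread_count = MSMIN(thread_num, UP_DIV(col_align, col_tile));
    int thread_stride = UP_DIV(UP_DIV(col_align, col_tile), thread_count);
    printf("col_tile=%d: col_align=%d thread_count=%d stride_cols=%d\n",
           col_tile, col_align, thread_count, thread_stride * col_tile);
  }
  return 0;
}

With these example inputs the 2-wide and 4-wide tiles happen to produce the same 4-column per-task stride, but col_align_ (10 vs. 12) and thread_count_ (4 vs. 3) differ, which is why keeping C4NUM in the alignment and partitioning math would have mismatched the 2-wide ARM32 packing layout.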