Browse Source

!14206 fix fc && matmul

From: @zoloft
Reviewed-by: @wangchengyuan,@hangangqiang
Signed-off-by: @wangchengyuan
pull/14206/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
ccf5b3d808
3 changed files with 22 additions and 5 deletions
  1. +12
    -5
      mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
  2. +2
    -0
      mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
  3. +8
    -0
      mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c

+ 12
- 5
mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc View File

@@ -58,11 +58,11 @@ int MatMulBaseInt8Coder::InitTmpBuffer() {
MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { FreeQuantParam(); } MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { FreeQuantParam(); }


void MatMulBaseInt8Coder::ResizeParameter() { void MatMulBaseInt8Coder::ResizeParameter() {
param_->row_align_ = UP_ROUND(param_->row_, C4NUM);
param_->col_align_ = UP_ROUND(param_->col_, C4NUM);
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM); param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, C4NUM));
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_);
thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
} }


void MatMulBaseInt8Coder::FreeQuantParam() { void MatMulBaseInt8Coder::FreeQuantParam() {
@@ -138,6 +138,12 @@ int MatMulBaseInt8Coder::InitQuantParam() {
void MatMulBaseInt8Coder::InitParameter() { void MatMulBaseInt8Coder::InitParameter() {
param_->a_const_ = (input_tensor_ != nullptr); param_->a_const_ = (input_tensor_ != nullptr);
param_->b_const_ = (filter_tensor_ != nullptr); param_->b_const_ = (filter_tensor_ != nullptr);
row_tile_ = C4NUM;
if (target_ == kARM32A) {
col_tile_ = C2NUM;
} else {
col_tile_ = C4NUM;
}
} }


int MatMulBaseInt8Coder::InitBias() { int MatMulBaseInt8Coder::InitBias() {
@@ -198,6 +204,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_,
"init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_); "init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_);
} else { } else {
code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_, false);
code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, param_->deep_, code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, param_->deep_,
param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, "init_filter_zp", param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, "init_filter_zp",
bias_ptr_, param_->b_transpose_, filter_per_channel_); bias_ptr_, param_->b_transpose_, filter_per_channel_);
@@ -225,7 +232,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_); std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_);
std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_); std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_);


int stride = thread_stride_ * C4NUM;
int stride = thread_stride_ * col_tile_;
int cur_stride = task_id * stride; int cur_stride = task_id * stride;
int res_stride = param_->col_ - cur_stride; int res_stride = param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride); int cur_oc = MSMIN(stride, res_stride);


+ 2
- 0
mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h View File

@@ -65,6 +65,8 @@ class MatMulBaseInt8Coder : public OperatorCoder {


private: private:
int weight_quant_num_{0}; int weight_quant_num_{0};
int row_tile_{C4NUM};
int col_tile_{C4NUM};
}; };
} // namespace mindspore::lite::micro::nnacl } // namespace mindspore::lite::micro::nnacl
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_

+ 8
- 0
mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c View File

@@ -38,10 +38,18 @@ void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_
int8_t *cur_b_pack = dst_ptr + i * col_align * deep_16; int8_t *cur_b_pack = dst_ptr + i * col_align * deep_16;
int32_t *cur_sums = weight_bias_sums_batch_ + i * col_align; int32_t *cur_sums = weight_bias_sums_batch_ + i * col_align;
if (b_transpose) { if (b_transpose) {
#ifdef ENABLE_ARM32
RowMajor2Row2x16MajorInt8(cur_b, cur_b_pack, col, deep);
#else
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep); RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep);
#endif
CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, filter_per_channel); CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, filter_per_channel);
} else { } else {
#ifdef ENABLE_ARM32
RowMajor2Col16x2MajorInt8(cur_b, cur_b_pack, deep, col);
#else
RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack); RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack);
#endif
CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false); CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false);
} }
} }


Loading…
Cancel
Save