|
|
|
@@ -58,11 +58,11 @@ int MatMulBaseInt8Coder::InitTmpBuffer() { |
|
|
|
MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { FreeQuantParam(); } |
|
|
|
|
|
|
|
void MatMulBaseInt8Coder::ResizeParameter() { |
|
|
|
param_->row_align_ = UP_ROUND(param_->row_, C4NUM); |
|
|
|
param_->col_align_ = UP_ROUND(param_->col_, C4NUM); |
|
|
|
param_->row_align_ = UP_ROUND(param_->row_, row_tile_); |
|
|
|
param_->col_align_ = UP_ROUND(param_->col_, col_tile_); |
|
|
|
param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM); |
|
|
|
thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, C4NUM)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_); |
|
|
|
thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, col_tile_)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_); |
|
|
|
} |
|
|
|
|
|
|
|
void MatMulBaseInt8Coder::FreeQuantParam() { |
|
|
|
@@ -138,6 +138,12 @@ int MatMulBaseInt8Coder::InitQuantParam() { |
|
|
|
void MatMulBaseInt8Coder::InitParameter() { |
|
|
|
param_->a_const_ = (input_tensor_ != nullptr); |
|
|
|
param_->b_const_ = (filter_tensor_ != nullptr); |
|
|
|
row_tile_ = C4NUM; |
|
|
|
if (target_ == kARM32A) { |
|
|
|
col_tile_ = C2NUM; |
|
|
|
} else { |
|
|
|
col_tile_ = C4NUM; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int MatMulBaseInt8Coder::InitBias() { |
|
|
|
@@ -198,6 +204,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { |
|
|
|
param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, |
|
|
|
"init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_); |
|
|
|
} else { |
|
|
|
code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_, false); |
|
|
|
code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, param_->deep_, |
|
|
|
param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, "init_filter_zp", |
|
|
|
bias_ptr_, param_->b_transpose_, filter_per_channel_); |
|
|
|
@@ -225,7 +232,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { |
|
|
|
std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_); |
|
|
|
std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_); |
|
|
|
|
|
|
|
int stride = thread_stride_ * C4NUM; |
|
|
|
int stride = thread_stride_ * col_tile_; |
|
|
|
int cur_stride = task_id * stride; |
|
|
|
int res_stride = param_->col_ - cur_stride; |
|
|
|
int cur_oc = MSMIN(stride, res_stride); |
|
|
|
|