@@ -33,7 +33,7 @@ int MatmulBaseInt8Run(void *cdata, int task_id) {
 }
 
 int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
-  int stride = thread_stride_ * C4NUM;
+  int stride = thread_stride_ * col_tile_;
   int cur_stride = task_id * stride;
   int res_stride = param_->col_ - cur_stride;
   int cur_oc = MSMIN(stride, res_stride);
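
Note: with this hunk, RunImpl strides over the output columns in units of col_tile_ instead of the hard-coded C4NUM, keeping each task's slice consistent with the tile width chosen in InitParameter below. A minimal standalone sketch of the striping, with hypothetical sizes and plain C++ stand-ins for the MSMIN/UP_DIV/UP_ROUND macros (an illustration, not the kernel code):

#include <algorithm>
#include <cstdio>

int main() {
  const int col_tile = 2;      // ARM32 column tile width after this patch
  const int col = 17;          // assumed number of output columns
  const int thread_count = 3;  // assumed thread pool size
  // col_align = UP_ROUND(col, col_tile); thread_stride = UP_DIV(tiles, threads)
  const int col_align = (col + col_tile - 1) / col_tile * col_tile;
  const int thread_stride = (col_align / col_tile + thread_count - 1) / thread_count;

  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int stride = thread_stride * col_tile;
    int cur_stride = task_id * stride;
    int cur_oc = std::min(stride, col - cur_stride);  // MSMIN(stride, res_stride)
    if (cur_oc <= 0) continue;  // empty slice: nothing for this task to do
    std::printf("task %d: columns [%d, %d)\n", task_id, cur_stride, cur_stride + cur_oc);
  }
  return 0;
}
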
@@ -155,16 +155,23 @@ void MatmulBaseInt8CPUKernel::InitQuantParam() {
 void MatmulBaseInt8CPUKernel::InitParameter() {
   param_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
   param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+#ifdef ENABLE_ARM32
+  row_tile_ = C4NUM;
+  col_tile_ = C2NUM;
+#else
+  row_tile_ = C4NUM;
+  col_tile_ = C4NUM;
+#endif
   return;
 }
 
 void MatmulBaseInt8CPUKernel::ResizeParameter() {
-  param_->row_align_ = UP_ROUND(param_->row_, C4NUM);
-  param_->col_align_ = UP_ROUND(param_->col_, C4NUM);
+  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
+  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
   param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
 
-  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, C4NUM));
-  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_);
+  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
+  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
   return;
 }
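
Note: ResizeParameter now derives both the alignment and the thread partition from the row_tile_/col_tile_ pair set in InitParameter, so ARM32 (4x2 output tiles) and the other targets (4x4) share one code path. A small self-checking illustration, assuming the usual nnacl-style definitions of the rounding macros (an assumption; the authoritative definitions live in op_base.h):

#include <cassert>

// Assumed macro semantics: round a count up to whole tiles.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) (UP_DIV((x), (y)) * (y))

int main() {
  assert(UP_ROUND(17, 2) == 18);   // col_align_ with the ARM32 col_tile_ of C2NUM
  assert(UP_DIV(18, 2) == 9);      // column tiles available to distribute
  assert(UP_ROUND(17, 16) == 32);  // deep_16_: deep is padded to 16 on all targets
  return 0;
}
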
@@ -195,11 +202,19 @@ void MatmulBaseInt8CPUKernel::TransferB() {
     auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
     auto current_sums = weight_bias_sums_ + i * param_->col_align_;
     if (param_->b_transpose_) {
+#ifdef ENABLE_ARM32
+      RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
+#else
       RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
+#endif
       CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
                          current_sums, ColMajor, filter_per_channel_);
     } else {
+#ifdef ENABLE_ARM32
+      RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_);
+#else
       RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack);
+#endif
       CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
                          current_sums, RowMajor, false);
     }
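
Note: the ARM32 branch packs B with the 2x16 routines (RowMajor2Row2x16MajorInt8 / RowMajor2Col16x2MajorInt8), presumably so the packed layout matches the 2-column tiles the ARM32 int8 matmul kernel consumes; the other targets keep the 16x4 packers. A rough, hypothetical sketch of what a col_tile-by-deep_tile int8 packer does (illustrating the layout idea only, not the nnacl implementation):

#include <cstdint>
#include <vector>

// Reorder a (col x deep) row-major int8 matrix into tiles of col_tile
// columns by deep_tile deep elements, zero-padding ragged edges, so an
// inner kernel can stream each tile as one contiguous block.
std::vector<int8_t> PackColTiles(const int8_t *src, int col, int deep,
                                 int col_tile, int deep_tile) {
  const int col_align = (col + col_tile - 1) / col_tile * col_tile;
  const int deep_align = (deep + deep_tile - 1) / deep_tile * deep_tile;
  std::vector<int8_t> dst(static_cast<size_t>(col_align) * deep_align, 0);
  for (int c = 0; c < col; ++c) {
    for (int d = 0; d < deep; ++d) {
      const int tile_c = c / col_tile, in_c = c % col_tile;
      const int tile_d = d / deep_tile, in_d = d % deep_tile;
      // Column tiles are laid out one after another; within a tile the
      // deep index runs fastest for each of the col_tile columns.
      const size_t offset =
          (static_cast<size_t>(tile_c) * (deep_align / deep_tile) + tile_d) *
              static_cast<size_t>(col_tile * deep_tile) +
          static_cast<size_t>(in_c) * deep_tile + in_d;
      dst[offset] = src[static_cast<size_t>(c) * deep + d];
    }
  }
  return dst;
}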