|
|
|
@@ -45,26 +45,52 @@ int MatmulInt8CPUKernel::ReSize() {
   params_->row_ = o_shape[o_shape.size() - 2];
   params_->col_ = o_shape[o_shape.size() - 1];
   params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
-  params_->row_8_ = UP_ROUND(params_->row_, 8);
-  params_->col_8_ = UP_ROUND(params_->col_, 8);
-
-  r4_ = UP_ROUND(params_->row_, 4);
-  c4_ = UP_ROUND(params_->col_, 4);
-  d16_ = UP_ROUND(params_->deep_, 16);
-  a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
+  params_->row_4_ = UP_ROUND(params_->row_, 4);
+  params_->col_4_ = UP_ROUND(params_->col_, 4);
+  params_->deep_16_ = UP_ROUND(params_->deep_, 16);
+  a_r4x16_ptr_ =
+    reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)));
   if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
-  memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
-  b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
-  if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
-  memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
-  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
+  memset(a_r4x16_ptr_, 0, params_->row_4_ * params_->deep_16_ * sizeof(int8_t));
+  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_4_ * sizeof(int)));
   if (!input_sums_) return RET_MEMORY_FAILED;
-  memset(input_sums_, 0, r4_ * sizeof(int));
-  weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
-  if (!weight_bias_sums_) return RET_MEMORY_FAILED;
-  memset(weight_bias_sums_, 0, c4_ * sizeof(int));
-  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
-  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+  memset(input_sums_, 0, params_->row_4_ * sizeof(int));
+  if (in_tensors_.size() == 3) {
+    auto bias_size = params_->col_4_ * sizeof(int);
+    bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_size));
+    if (!bias_ptr_) return RET_MEMORY_FAILED;
+    memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_size);
+  } else {
+    bias_ptr_ = NULL;
+  }
+
+  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+  if (params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * params_->deep_ * params_->col_;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_4_, 4));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_);
 
   auto input_tensor = in_tensors_[0];
   auto params = input_tensor->GetQuantParams();
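The padded sizes introduced above (`row_4_`, `col_4_`, `deep_16_`) exist because the int8 kernels consume A as 4-row-by-16-depth tiles and B as 16-depth-by-4-column tiles, so every dimension is rounded up to its tile size before the pack buffers are allocated. A minimal, self-contained sketch of the rounding helpers and the values they produce (assumed macro semantics, not the verbatim nnacl definitions):

```cpp
// Sketch of UP_DIV / UP_ROUND as used in the hunk above (assumed definitions;
// the real macros live in MindSpore Lite's nnacl headers).
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // ceiling division
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))   // round x up to a multiple of y

int main() {
  // e.g. row_ = 10, col_ = 6, deep_ = 20:
  printf("row_4_   = %d\n", UP_ROUND(10, 4));   // 12: rows padded to a multiple of 4
  printf("col_4_   = %d\n", UP_ROUND(6, 4));    // 8:  cols padded to a multiple of 4
  printf("deep_16_ = %d\n", UP_ROUND(20, 16));  // 32: depth padded to a multiple of 16
  return 0;
}
```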
|
|
|
@@ -89,23 +115,24 @@ int MatmulInt8CPUKernel::ReSize() {
 }
 
 int MatmulInt8CPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, 4) - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
   int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
-  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
+  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * params_->deep_16_;
   auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
   auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
 
   auto &p = quant_params_;
 #ifdef ENABLE_ARM64
-  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
-                   p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
-                   params_->col_ * sizeof(int8_t), false);
+  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, params_->row_4_, cur_oc * C4NUM, params_->deep_16_, input_sums_,
+                   cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift,
+                   params_->row_, cur_oc_res, params_->col_ * sizeof(int8_t), false);
 #else
-  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, d16_, params_->col_, input_sums_, cur_bias,
-                    &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN, INT8_MAX, false);
+  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, params_->deep_16_, params_->col_,
+                    input_sums_, cur_bias, &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN,
+                    INT8_MAX, false);
 #endif
 
   return RET_OK;
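`RunImpl` divides the padded output columns across threads in units of 4-column tiles: `thread_stride_` is the tile count per task, and `cur_oc` clamps the last task, which may receive fewer tiles or none (hence the early `RET_OK`). A standalone sketch of that split with illustrative numbers:

```cpp
// Illustrative sketch of the per-task column-tile split (names and values are
// examples, not the kernel's actual members).
#include <algorithm>
#include <cstdio>

int main() {
  const int col_4 = 24;                      // columns already padded to a multiple of 4
  const int oc_tiles = (col_4 + 3) / 4;      // UP_DIV(col_4_, 4): six 4-column tiles
  int thread_count = std::min(4, oc_tiles);  // never spawn more tasks than tiles
  int thread_stride = (oc_tiles + thread_count - 1) / thread_count;  // tiles per task

  for (int task_id = 0; task_id < thread_count; ++task_id) {
    // Mirrors cur_oc in RunImpl: the final task may get fewer tiles, or none.
    int cur_oc = std::min(thread_stride, oc_tiles - task_id * thread_stride);
    if (cur_oc <= 0) continue;  // the kernel simply returns RET_OK for such tasks
    printf("task %d: %d tile(s), columns [%d, %d)\n", task_id, cur_oc,
           task_id * thread_stride * 4, (task_id * thread_stride + cur_oc) * 4);
  }
  return 0;
}
```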
|
|
|
@@ -127,33 +154,47 @@ int MatmulInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-  auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData());
-  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   auto a_stride = params_->row_ * params_->deep_;
   auto b_stride = params_->deep_ * params_->col_;
   auto c_stride = params_->row_ * params_->col_;
 
+  if (!params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * b_stride;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+
   for (int i = 0; i < params_->batch; ++i) {
     auto cur_a_ptr = a_ptr + i * a_stride;
-    auto cur_b_ptr = b_ptr + i * b_stride;
-
     if (params_->a_transpose_) {
-      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, params_->deep_16_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
-    if (params_->b_transpose_) {
-      RowMajor2Row16x4MajorInt8(cur_b_ptr, b_c16x4_ptr_, params_->col_, params_->deep_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, ColMajor);
-    } else {
-      RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, RowMajor);
-    }
+    b_c16x4_ptr_ = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+    weight_bias_sums_ = weight_bias_sums_batch_ + i * params_->col_4_;
     c_ptr_ = c_ptr + i * c_stride;
     ret = ParallelLaunch(this->context_->thread_pool_, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
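For reference, the reason the diff keeps `input_sums_` and per-batch `weight_bias_sums_` at all: with asymmetric quantization, expanding sum_d (a - za)(b - zb) lets the inner loop run on raw int8 products while the zero-point terms are folded into precomputed row/column sums. A scalar sketch of that decomposition (illustrative names; the mapping onto the two sum buffers is my reading of the calls above, and bias folding is omitted):

```cpp
// Reference check of the zero-point decomposition behind the sum buffers
// (illustrative only; the real kernels are the packed nnacl routines).
#include <cstdint>
#include <cstdio>

int main() {
  const int row = 2, col = 2, deep = 3;
  const int32_t za = 1, zb = -2;  // input / weight zero points
  const int8_t A[row * deep] = {1, 2, 3, 4, 5, 6};
  const int8_t B[deep * col] = {7, 8, 9, 10, 11, 12};

  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      // Naive form: accumulate (a - za) * (b - zb) directly.
      int32_t ref = 0;
      for (int d = 0; d < deep; ++d) ref += (A[r * deep + d] - za) * (B[d * col + c] - zb);

      // Decomposed form: raw dot product corrected by precomputed sums --
      // the role played by input_sums_ and weight_bias_sums_ above.
      int32_t dot = 0, a_sum = 0, b_sum = 0;
      for (int d = 0; d < deep; ++d) {
        dot += A[r * deep + d] * B[d * col + c];
        a_sum += A[r * deep + d];
        b_sum += B[d * col + c];
      }
      int32_t input_sum = zb * a_sum;                    // per-row, like input_sums_
      int32_t weight_sum = za * b_sum - deep * za * zb;  // per-column, like weight_bias_sums_
      printf("C[%d][%d]: ref=%d decomposed=%d\n", r, c, ref, dot - input_sum - weight_sum);
    }
  }
  return 0;
}
```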
|
|
|
|