diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
index 942c11fd07..c817d04168 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
@@ -170,36 +170,31 @@ int Convolution1x1Int8CPUKernel::InitParam() {
   matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
   matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM);
 
-  /* init input sum size */
+  /* row/col rounding unit: C8NUM when support_optimize_ is set, C4NUM otherwise */
+  int row_pack_count = 0;
+  int col_pack_count = 0;
   if (support_optimize_) {
-    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-      input_sum_size = UP_ROUND(conv_param_->output_channel_, C8NUM) * UP_ROUND(matmul_param_->row_, C8NUM);
-    } else {
-      input_sum_size = UP_ROUND(matmul_param_->row_, C8NUM);
-    }
+    row_pack_count = C8NUM;
+    col_pack_count = C8NUM;
   } else {
-    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-      input_sum_size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(matmul_param_->row_, C4NUM);
-    } else {
-      input_sum_size = UP_ROUND(matmul_param_->row_, C4NUM);
-    }
+    row_pack_count = C4NUM;
+    col_pack_count = C4NUM;
   }
 
-  if (support_optimize_) {
-    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
-    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_);
+  /* init input sum size */
+  if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
+    input_sum_size = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
   } else {
-    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C4NUM));
-    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C4NUM), thread_count_);
+    input_sum_size = UP_ROUND(matmul_param_->row_, row_pack_count);
   }
 
-  if (support_optimize_) {
-    thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C8NUM));
-    thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C8NUM), thread_count_hw_);
-  } else {
-    thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C4NUM));
-    thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C4NUM), thread_count_hw_);
-  }
+  /* thread split over col_, counted in col_pack_count-sized units */
+  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_pack_count));
+  thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_pack_count), thread_count_);
+
+  /* thread split over row_, counted in row_pack_count-sized units */
+  thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count));
+  thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_hw_);
 
   if (pre_trans_input_) {
     input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
@@ -296,7 +291,7 @@ int Convolution1x1Int8Impl(void *cdata, int task_id) {
 }
 
 int Convolution1x1Int8CPUKernel::InitRunBuf() {
-  input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size * sizeof(int32_t)));
+  input_sum_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(input_sum_size * sizeof(int32_t)));
   if (input_sum_ == nullptr) {
     MS_LOG(ERROR) << "malloc input_sum_ failed.";
     return RET_ERROR;
@@ -334,6 +329,7 @@ int Convolution1x1Int8CPUKernel::Run() {
   int error_code = InitRunBuf();
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "conv1x1 int8 InitRunBuf error_code[" << error_code << "]";
+    FreeRunBuf();
     return RET_ERROR;
   }
 