|
|
|
@@ -82,8 +82,8 @@ int DeConvolutionFp16CPUKernel::InitParam() { |
|
|
|
matmul_param_->row_ = input_plane_; |
|
|
|
matmul_param_->deep_ = conv_param_->input_channel_; |
|
|
|
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; |
|
|
|
row16_ = UP_ROUND(matmul_param_->row_, 16); |
|
|
|
col8_ = UP_ROUND(matmul_param_->col_, 8); |
|
|
|
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM); |
|
|
|
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_; |
|
|
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_); |
|
|
|
@@ -98,13 +98,15 @@ int DeConvolutionFp16CPUKernel::InitRunBuf() { |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
|
|
|
|
tmp_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(row16_ * col8_ * sizeof(float16_t))); |
|
|
|
tmp_buffer_ = reinterpret_cast<float16_t *>( |
|
|
|
ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->col_8_ * sizeof(float16_t))); |
|
|
|
if (tmp_buffer_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "deconv Malloc tmp_buffer_ error!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
pack_input_ = reinterpret_cast<float16_t *>(malloc(row16_ * matmul_param_->deep_ * sizeof(float16_t))); |
|
|
|
pack_input_ = |
|
|
|
reinterpret_cast<float16_t *>(malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t))); |
|
|
|
if (pack_input_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -147,7 +149,7 @@ int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * row16_; |
|
|
|
auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_16_; |
|
|
|
MatMulFp16(pack_input_, execute_weight_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, |
|
|
|
tmp_buf, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_, oc * C8NUM * kernel_plane_, 0, |
|
|
|
false); |
|
|
|
|