|
|
|
@@ -95,12 +95,22 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { |
|
|
|
} |
|
|
|
|
|
|
|
int Convolution1x1CPUKernel::InitConv1x1Param() { |
|
|
|
int hw_tile = C12NUM; |
|
|
|
#ifdef ENABLE_ARM32 |
|
|
|
hw_tile = C4NUM; |
|
|
|
#endif |
|
|
|
if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { |
|
|
|
multi_thread_by_hw_ = true; |
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile; |
|
|
|
} else { |
|
|
|
multi_thread_by_hw_ = false; |
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; |
|
|
|
} |
|
|
|
|
|
|
|
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || |
|
|
|
conv_param_->stride_w_ != 1); |
|
|
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); |
|
|
|
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; |
|
|
|
|
|
|
|
if (pre_trans_input_) { |
|
|
|
input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float))); |
|
|
|
if (input_ptr_ == nullptr) { |
|
|
|
@@ -113,22 +123,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) { |
|
|
|
output_ptr_ = src_output; |
|
|
|
|
|
|
|
if (pre_trans_input_) { |
|
|
|
Conv1x1InputPack(src_input, input_ptr_, conv_param_, sizeof(float)); |
|
|
|
} else { |
|
|
|
input_ptr_ = src_input; |
|
|
|
} |
|
|
|
#ifdef ENABLE_ARM32 |
|
|
|
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); |
|
|
|
#else |
|
|
|
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); |
|
|
|
#endif |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
int Convolution1x1CPUKernel::Init() { |
|
|
|
int error_code = InitConv1x1BiasWeight(); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
@@ -164,6 +158,40 @@ int Convolution1x1Run(void *cdata, int task_id) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) { |
|
|
|
int res_stride = matmul_param_->row_ - task_id * thread_stride_; |
|
|
|
int cur_hw_ = MSMIN(thread_stride_, res_stride); |
|
|
|
if (cur_hw_ <= 0) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; |
|
|
|
float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; |
|
|
|
|
|
|
|
#ifdef ENABLE_ARM32 |
|
|
|
RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); |
|
|
|
#else |
|
|
|
RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); |
|
|
|
#endif |
|
|
|
|
|
|
|
float *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; |
|
|
|
MatMulOpt(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float *>(bias_data_), |
|
|
|
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_, |
|
|
|
OutType_Nhwc); |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int Convolution1x1RunHw(void *cdata, int task_id) { |
|
|
|
auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); |
|
|
|
auto error_code = conv1x1->DoConv1x1Hw(task_id); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convolution1x1Run error task_id[" << task_id << "] error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int Convolution1x1CPUKernel::Run() { |
|
|
|
auto prepare_ret = Prepare(); |
|
|
|
if (prepare_ret != RET_OK) { |
|
|
|
@@ -186,13 +214,23 @@ int Convolution1x1CPUKernel::Run() { |
|
|
|
} |
|
|
|
|
|
|
|
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { |
|
|
|
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_, |
|
|
|
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_); |
|
|
|
output_ptr_ = src_out + batch_index * matmul_param_->row_ * matmul_param_->col_; |
|
|
|
auto tmp_in = src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_; |
|
|
|
if (pre_trans_input_) { |
|
|
|
Conv1x1InputPack(tmp_in, input_ptr_, conv_param_, sizeof(float)); |
|
|
|
} else { |
|
|
|
input_ptr_ = tmp_in; |
|
|
|
} |
|
|
|
|
|
|
|
int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Run, this, thread_count_); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
if (multi_thread_by_hw_) { |
|
|
|
ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_); |
|
|
|
} else { |
|
|
|
#ifdef ENABLE_ARM32 |
|
|
|
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); |
|
|
|
#else |
|
|
|
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); |
|
|
|
#endif |
|
|
|
ParallelLaunch(this->context_->thread_pool_, Convolution1x1Run, this, thread_count_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|