Merge pull request !4522 from fuzhiye/tmptags/v0.7.0-beta
| @@ -171,7 +171,6 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) { | |||
| pooling_param->global_ = pooling_primitive->global(); | |||
| pooling_param->window_w_ = pooling_primitive->windowW(); | |||
| pooling_param->window_h_ = pooling_primitive->windowH(); | |||
| // todo format | |||
| auto pooling_lite_primitive = (lite::Pooling *)primitive; | |||
| MS_ASSERT(nullptr != pooling_lite_primitive); | |||
| pooling_param->pad_u_ = pooling_lite_primitive->PadUp(); | |||
| @@ -181,6 +180,8 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) { | |||
| pooling_param->stride_w_ = pooling_primitive->strideW(); | |||
| pooling_param->stride_h_ = pooling_primitive->strideH(); | |||
| auto is_global = pooling_primitive->global(); | |||
| pooling_param->global_ = is_global; | |||
| auto pool_mode = pooling_primitive->poolingMode(); | |||
| switch (pool_mode) { | |||
| case schema::PoolMode_MAX_POOLING: | |||
| @@ -76,6 +76,10 @@ int PoolingBaseCPUKernel::Init() { | |||
| pooling_param_->output_channel_ = out_tensor->Channel(); | |||
| pooling_param_->output_h_ = out_tensor->Height(); | |||
| pooling_param_->output_w_ = out_tensor->Width(); | |||
| if (pooling_param_->global_) { | |||
| pooling_param_->window_h_ = pooling_param_->input_h_; | |||
| pooling_param_->window_w_ = pooling_param_->input_w_; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -27,16 +27,125 @@ | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_MEMORY_FAILED; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Conv2D; | |||
| namespace mindspore::kernel { | |||
| int Convolution1x1FP16CPUKernel::InitMatmulParam() { | |||
| matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | |||
| matmul_param_->col_ = conv_param_->output_channel_; | |||
| matmul_param_->deep_ = conv_param_->input_channel_; | |||
| matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM); | |||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | |||
| matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No; | |||
| matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_; | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1FP16CPUKernel::InitConv1x1Param() { | |||
| pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 || | |||
| conv_param_->stride_w_ != 1); | |||
| if (pre_trans_input_) { | |||
| input_ptr_ = reinterpret_cast<float16_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t))); | |||
| if (input_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t)); | |||
| } | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; | |||
| pack_input_ = | |||
| reinterpret_cast<float16_t *>(malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t))); | |||
| if (pack_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(pack_input_, 0, matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)); | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||
| auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute filter failed."; | |||
| return ret; | |||
| } | |||
| if (in_tensors_.size() == 3) { | |||
| bias_data_ = malloc(matmul_param_->col_8_ * sizeof(float16_t)); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(bias_data_, 0, matmul_param_->col_8_ * sizeof(float16_t)); | |||
| memcpy(bias_data_, in_tensors_[2]->Data(), conv_param_->output_channel_ * sizeof(float16_t)); | |||
| } else { | |||
| bias_data_ = nullptr; | |||
| } | |||
| weight_ptr_ = reinterpret_cast<float16_t *>(malloc(matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t))); | |||
| if (weight_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t)); | |||
| RowMajor2Col8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->col_, | |||
| matmul_param_->deep_); | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1FP16CPUKernel::InitBuffer() { | |||
| /*=============================fp16_input_============================*/ | |||
| size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ * | |||
| conv_param_->input_w_ * sizeof(float16_t); | |||
| fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | |||
| if (fp16_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc fp16_input_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(fp16_input_, 0, fp16_input_size); | |||
| /*=============================fp16_out_============================*/ | |||
| size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ * | |||
| conv_param_->output_w_ * sizeof(float16_t); | |||
| fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size)); | |||
| if (fp16_out_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc fp16_out_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1FP16CPUKernel::Init() { | |||
| auto ret = ConvolutionBaseCPUKernel::Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionBase init failed."; | |||
| return ret; | |||
| } | |||
| ret = InitMatmulParam(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init matmul param failed."; | |||
| return ret; | |||
| } | |||
| ret = InitConv1x1Param(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init conv1x1 param failed."; | |||
| return ret; | |||
| } | |||
| ret = InitBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init buffer failed."; | |||
| return ret; | |||
| } | |||
| ret = InitWeightBias(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init weight bias failed."; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -47,8 +156,14 @@ int Convolution1x1FP16CPUKernel::ReSize() { | |||
| if (fp16_input_ != nullptr) { | |||
| free(fp16_input_); | |||
| } | |||
| if (nhwc4_input_ != nullptr) { | |||
| free(nhwc4_input_); | |||
| if (fp16_weight_ != nullptr) { | |||
| free(fp16_weight_); | |||
| } | |||
| if (input_ptr_ != nullptr) { | |||
| free(input_ptr_); | |||
| } | |||
| if (weight_ptr_ != nullptr) { | |||
| free(weight_ptr_); | |||
| } | |||
| auto ret = ConvolutionBaseCPUKernel::Init(); | |||
| @@ -56,13 +171,49 @@ int Convolution1x1FP16CPUKernel::ReSize() { | |||
| MS_LOG(ERROR) << "ConvolutionBase init failed."; | |||
| return ret; | |||
| } | |||
| ret = InitMatmulParam(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init matmul param failed."; | |||
| return ret; | |||
| } | |||
| ret = InitConv1x1Param(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init conv1x1 param failed."; | |||
| return ret; | |||
| } | |||
| ret = InitBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init buffer failed."; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *src_output) { | |||
| output_ptr_ = src_output; | |||
| if (pre_trans_input_) { | |||
| Conv1x1InputPackFp16(src_input, input_ptr_, conv_param_); | |||
| } else { | |||
| input_ptr_ = src_input; | |||
| } | |||
| RowMajor2Col8MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| return; | |||
| } | |||
| int Convolution1x1FP16CPUKernel::RunImpl(int task_id) { | |||
| // Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_, | |||
| // reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, | |||
| // tmp_dst_buffer_, tmp_out_, task_id, conv_param_); | |||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float16_t *>(bias_data_) + thread_stride_ * task_id; | |||
| MatMulFp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | |||
| output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, | |||
| matmul_param_->row_, cur_oc, matmul_param_->col_, true); | |||
| return RET_OK; | |||
| } | |||
| @@ -83,12 +234,22 @@ int Convolution1x1FP16CPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get executor tensor failed."; | |||
| return ret; | |||
| } | |||
| for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { | |||
| Pre1x1Trans( | |||
| execute_input_ + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_, | |||
| execute_output_ + batch_index * matmul_param_->row_ * matmul_param_->col_); | |||
| int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::IfCastOutput(); | |||
| @@ -22,6 +22,8 @@ | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" | |||
| #include "src/runtime/kernel/arm/nnacl/matmul_parameter.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h" | |||
| namespace mindspore::kernel { | |||
| class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| @@ -29,7 +31,9 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | |||
| const lite::Primitive *primitive) | |||
| : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| matmul_param_ = new MatMulParameter(); | |||
| } | |||
| ~Convolution1x1FP16CPUKernel() override { | |||
| if (fp16_input_ != nullptr) { | |||
| free(fp16_input_); | |||
| @@ -40,14 +44,34 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| if (fp16_out_ != nullptr) { | |||
| free(fp16_out_); | |||
| } | |||
| if (input_ptr_ != nullptr) { | |||
| free(input_ptr_); | |||
| } | |||
| if (weight_ptr_ != nullptr) { | |||
| free(weight_ptr_); | |||
| } | |||
| delete matmul_param_; | |||
| } | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int RunImpl(int task_id); | |||
| int InitBuffer(); | |||
| int InitConv1x1Param(); | |||
| int InitMatmulParam(); | |||
| int InitWeightBias(); | |||
| void Pre1x1Trans(float16_t *src_input, float16_t *src_output); | |||
| private: | |||
| bool pre_trans_input_ = false; | |||
| int thread_count_ = 0; | |||
| int thread_stride_ = 0; | |||
| float16_t *weight_ptr_ = nullptr; | |||
| float16_t *input_ptr_ = nullptr; | |||
| float16_t *pack_input_ = nullptr; | |||
| float16_t *output_ptr_ = nullptr; | |||
| MatMulParameter *matmul_param_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -62,7 +62,11 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() { | |||
| return RET_ERROR; | |||
| } | |||
| memset(transformed_filter_addr_, 0, transformed_size); | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute filter failed."; | |||
| return ret; | |||
| } | |||
| ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_); | |||
| // init bias | |||
| @@ -249,8 +253,11 @@ int Convolution3x3FP16CPUKernel::Run() { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get execute tensor failed."; | |||
| return ret; | |||
| } | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| int in_w = conv_param_->input_w_; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" | |||
| @@ -46,7 +47,11 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||
| int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; | |||
| // init weight | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute filter failed."; | |||
| return ret; | |||
| } | |||
| packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t))); | |||
| if (packed_weight_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc packed_weight_ failed."; | |||
| @@ -218,7 +223,12 @@ int ConvolutionFP16CPUKernel::Run() { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute tensor failed."; | |||
| return ret; | |||
| } | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| @@ -256,6 +266,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else if (kernel_h == 1 && kernel_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else { | |||
| bool use_winograd = false; | |||
| int out_unit; | |||
| @@ -39,7 +39,11 @@ int ConvolutionSWFP16CPUKernel::ProcessFilter() { | |||
| int out_channel = conv_param_->output_channel_; | |||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute filter failed."; | |||
| return ret; | |||
| } | |||
| for (int oc = 0; oc < out_channel; ++oc) { | |||
| int src_oc_offset = oc * kernel_h * kernel_w * in_channel; | |||
| @@ -228,7 +232,11 @@ int ConvolutionSWFP16CPUKernel::Run() { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute tensor failed."; | |||
| return ret; | |||
| } | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| @@ -115,7 +115,11 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||
| return RET_ERROR; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute filter failed."; | |||
| return ret; | |||
| } | |||
| WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block); | |||
| // init bias | |||
| @@ -377,7 +381,11 @@ int ConvolutionWinogradFP16CPUKernel::Run() { | |||
| return prepare_ret; | |||
| } | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Get Execute tensor failed."; | |||
| return ret; | |||
| } | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| @@ -16,6 +16,11 @@ | |||
| #include "nnacl/fp16/matmul_fp16.h" | |||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | |||
| int depth, int row, int col, int stride, bool write_nhwc) { | |||
| MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); | |||
| } | |||
| void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | |||
| size_t row16 = row / C16NUM * C16NUM; | |||
| size_t col8 = col / C8NUM * C8NUM; | |||
| @@ -134,7 +139,7 @@ void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | |||
| "v30", "v31"); | |||
| #else | |||
| for (int tr = 0; tr < C16NUM; tr++) { | |||
| @@ -29,10 +29,13 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | |||
| int depth, int row, int col, int stride, bool write_nhwc); | |||
| void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | |||
| #ifdef __aarch64__ | |||
| void MatmulFp16Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col); | |||
| int col, int stride, bool write_nhwc); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -18,6 +18,27 @@ | |||
| #include <string.h> | |||
| #include <stdlib.h> | |||
| void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) { | |||
| /* support nhwc */ | |||
| for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { | |||
| int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_; | |||
| if (src_h < 0 || src_h >= conv_param->input_h_) { | |||
| continue; | |||
| } | |||
| const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_; | |||
| float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_; | |||
| for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { | |||
| int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_; | |||
| if (src_w < 0 || src_w >= conv_param->input_w_) { | |||
| continue; | |||
| } | |||
| memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_, | |||
| conv_param->input_channel_ * sizeof(float16_t)); | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, | |||
| int block_index) { | |||
| // input format : nhwc | |||
| @@ -26,6 +26,8 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param); | |||
| void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, | |||
| int block_index); | |||
| @@ -26,6 +26,7 @@ typedef struct MatMulParameter { | |||
| int row_; | |||
| int col_; | |||
| int row_8_; | |||
| int row_16_; | |||
| int col_8_; | |||
| int deep_; | |||
| bool has_bias_; | |||