diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc
index 40840f8424..f31d50603c 100644
--- a/mindspore/lite/src/populate_parameter.cc
+++ b/mindspore/lite/src/populate_parameter.cc
@@ -171,7 +171,6 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) {
   pooling_param->global_ = pooling_primitive->global();
   pooling_param->window_w_ = pooling_primitive->windowW();
   pooling_param->window_h_ = pooling_primitive->windowH();
-  // todo format
   auto pooling_lite_primitive = (lite::Pooling *)primitive;
   MS_ASSERT(nullptr != pooling_lite_primitive);
   pooling_param->pad_u_ = pooling_lite_primitive->PadUp();
@@ -181,6 +180,8 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) {
   pooling_param->stride_w_ = pooling_primitive->strideW();
   pooling_param->stride_h_ = pooling_primitive->strideH();
 
+  auto is_global = pooling_primitive->global();
+  pooling_param->global_ = is_global;
   auto pool_mode = pooling_primitive->poolingMode();
   switch (pool_mode) {
     case schema::PoolMode_MAX_POOLING:
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
index 010467116d..c1decb1082 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
@@ -76,6 +76,10 @@ int PoolingBaseCPUKernel::Init() {
   pooling_param_->output_channel_ = out_tensor->Channel();
   pooling_param_->output_h_ = out_tensor->Height();
   pooling_param_->output_w_ = out_tensor->Width();
+  if (pooling_param_->global_) {
+    pooling_param_->window_h_ = pooling_param_->input_h_;
+    pooling_param_->window_w_ = pooling_param_->input_w_;
+  }
   return RET_OK;
 }
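Reviewer note: with `global_` now carried through PopulatePoolingParameter, `PoolingBaseCPUKernel::Init()` widens the window to the whole input plane, so a global pooling node becomes a single window per channel. A minimal sketch of what that reduces to, assuming an NHWC float32 buffer (the function below is illustrative, not part of this patch):

```cpp
#include <vector>

// Illustrative: global average pooling over an NHWC buffer. When global_ is
// set, window_h_/window_w_ == input_h_/input_w_, so each output value reduces
// the entire spatial plane of its channel.
std::vector<float> GlobalAvgPoolNHWC(const std::vector<float> &in, int n, int h, int w, int c) {
  std::vector<float> out(static_cast<size_t>(n) * c, 0.0f);
  for (int b = 0; b < n; ++b) {
    for (int y = 0; y < h; ++y) {
      for (int x = 0; x < w; ++x) {
        for (int ch = 0; ch < c; ++ch) {
          out[b * c + ch] += in[((b * h + y) * w + x) * c + ch];
        }
      }
    }
  }
  for (auto &v : out) v /= static_cast<float>(h * w);  // one window == whole plane
  return out;
}
```

The kernel itself keeps using the ordinary windowed path; only the window dimensions change, which is why no dedicated global-pooling routine is needed.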
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
index 529193ef5f..5d4e361f64 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
@@ -27,16 +27,125 @@
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Conv2D;
 
 namespace mindspore::kernel {
+int Convolution1x1FP16CPUKernel::InitMatmulParam() {
+  matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
+  matmul_param_->col_ = conv_param_->output_channel_;
+  matmul_param_->deep_ = conv_param_->input_channel_;
+  matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
+  matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
+  matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
+  matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
+  return RET_OK;
+}
+
+int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
+  pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
+                      conv_param_->stride_w_ != 1);
+  if (pre_trans_input_) {
+    input_ptr_ = reinterpret_cast<float16_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t)));
+    if (input_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
+      return RET_MEMORY_FAILED;
+    }
+    memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t));
+  }
+
+  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
+  thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
+
+  pack_input_ =
+    reinterpret_cast<float16_t *>(malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)));
+  if (pack_input_ == nullptr) {
+    MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
+    return RET_MEMORY_FAILED;
+  }
+  memset(pack_input_, 0, matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t));
+  return RET_OK;
+}
+
+int Convolution1x1FP16CPUKernel::InitWeightBias() {
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute filter failed.";
+    return ret;
+  }
+  if (in_tensors_.size() == 3) {
+    bias_data_ = malloc(matmul_param_->col_8_ * sizeof(float16_t));
+    if (bias_data_ == nullptr) {
+      MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
+      return RET_ERROR;
+    }
+    memset(bias_data_, 0, matmul_param_->col_8_ * sizeof(float16_t));
+    memcpy(bias_data_, in_tensors_[2]->Data(), conv_param_->output_channel_ * sizeof(float16_t));
+  } else {
+    bias_data_ = nullptr;
+  }
+
+  weight_ptr_ =
+    reinterpret_cast<float16_t *>(malloc(matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t)));
+  if (weight_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
+    return RET_ERROR;
+  }
+  memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t));
+  RowMajor2Col8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->col_,
+                         matmul_param_->deep_);
+
+  return RET_OK;
+}
+
+int Convolution1x1FP16CPUKernel::InitBuffer() {
+  /*=============================fp16_input_============================*/
+  size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
+                           conv_param_->input_w_ * sizeof(float16_t);
+  fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
+  if (fp16_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
+    return RET_ERROR;
+  }
+  memset(fp16_input_, 0, fp16_input_size);
+
+  /*=============================fp16_out_============================*/
+  size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ *
+                            conv_param_->output_w_ * sizeof(float16_t);
+  fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
+  if (fp16_out_ == nullptr) {
+    MS_LOG(ERROR) << "malloc fp16_out_ failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int Convolution1x1FP16CPUKernel::Init() {
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
     return ret;
   }
+  ret = InitMatmulParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init matmul param failed.";
+    return ret;
+  }
+  ret = InitConv1x1Param();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init conv1x1 param failed.";
+    return ret;
+  }
+  ret = InitBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init buffer failed.";
+    return ret;
+  }
+  ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init weight bias failed.";
+    return ret;
+  }
   return RET_OK;
 }
 
@@ -47,8 +156,14 @@ int Convolution1x1FP16CPUKernel::ReSize() {
   if (fp16_input_ != nullptr) {
     free(fp16_input_);
   }
-  if (nhwc4_input_ != nullptr) {
-    free(nhwc4_input_);
+  if (fp16_weight_ != nullptr) {
+    free(fp16_weight_);
+  }
+  if (input_ptr_ != nullptr) {
+    free(input_ptr_);
+  }
+  if (weight_ptr_ != nullptr) {
+    free(weight_ptr_);
   }
 
   auto ret = ConvolutionBaseCPUKernel::Init();
@@ -56,13 +171,49 @@ int Convolution1x1FP16CPUKernel::ReSize() {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
     return ret;
   }
+  ret = InitMatmulParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init matmul param failed.";
+    return ret;
+  }
+  ret = InitConv1x1Param();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init conv1x1 param failed.";
+    return ret;
+  }
+  ret = InitBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init buffer failed.";
+    return ret;
+  }
+
   return RET_OK;
 }
 
+void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *src_output) {
+  output_ptr_ = src_output;
+  if (pre_trans_input_) {
+    Conv1x1InputPackFp16(src_input, input_ptr_, conv_param_);
+  } else {
+    input_ptr_ = src_input;
+  }
+
+  RowMajor2Col8MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+  return;
+}
+
 int Convolution1x1FP16CPUKernel::RunImpl(int task_id) {
-  // Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
-  //             reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_,
-  //             tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
+  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
+
+  auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float16_t *>(bias_data_) + thread_stride_ * task_id;
+
+  MatMulFp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
+             output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
+             matmul_param_->row_, cur_oc, matmul_param_->col_, true);
+
   return RET_OK;
 }
 
@@ -83,12 +234,22 @@ int Convolution1x1FP16CPUKernel::Run() {
     return RET_ERROR;
   }
 
-  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get executor tensor failed.";
+    return ret;
+  }
+
+  for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
+    Pre1x1Trans(
+      execute_input_ + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
+      execute_output_ + batch_index * matmul_param_->row_ * matmul_param_->col_);
 
-  int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
-    return RET_ERROR;
+    int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
+    if (error_code != RET_OK) {
+      MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
+      return RET_ERROR;
+    }
   }
 
   ConvolutionBaseFP16CPUKernel::IfCastOutput();
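Reviewer note: `InitMatmulParam()` recasts the 1x1 convolution as a GEMM with `row_ = output_h * output_w`, `col_ = output_channel`, and `deep_ = input_channel`, padded to the NEON micro-kernel's 16-row / 8-column tiles, while `InitConv1x1Param()` splits output columns across threads in `C8NUM` blocks. A small standalone sketch of that arithmetic, assuming `UP_DIV`/`UP_ROUND` behave as in nnacl's macros (constants and shapes below are illustrative):

```cpp
#include <algorithm>
#include <cstdio>

// Mirrors nnacl's rounding macros (assumption: same semantics as in op_base.h).
constexpr int C8NUM = 8;
constexpr int C16NUM = 16;
constexpr int UpDiv(int a, int b) { return (a + b - 1) / b; }
constexpr int UpRound(int a, int b) { return UpDiv(a, b) * b; }

int main() {
  // Hypothetical 1x1 conv shape: 14x14 output, 96 input channels, 24 output channels.
  int row = 14 * 14, col = 24, deep = 96;
  int row_16 = UpRound(row, C16NUM);  // 208: rows padded for the 16-wide micro-kernel
  int col_8 = UpRound(col, C8NUM);    // 24: columns padded to whole 8-wide tiles
  int thread_num = 4;
  int thread_count = std::min(thread_num, UpDiv(col, C8NUM));          // 3 threads are useful
  int thread_stride = UpDiv(UpDiv(col, C8NUM), thread_count) * C8NUM;  // 8 columns per task
  std::printf("row_16=%d col_8=%d threads=%d stride=%d deep=%d\n",
              row_16, col_8, thread_count, thread_stride, deep);
  return 0;
}
```

Rounding `col_` up to `col_8_` is also why the bias and weight buffers are over-allocated and zero-filled before the real data is copied in: the micro-kernel can then read whole 8-column tiles without edge checks.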
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
index 989e1e2bc7..d88e79755f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
@@ -22,6 +22,8 @@
 #include "src/lite_kernel.h"
 #include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
+#include "src/runtime/kernel/arm/nnacl/matmul_parameter.h"
+#include "src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h"
 
 namespace mindspore::kernel {
 class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
@@ -29,7 +31,9 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
   Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                               const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                               const lite::Primitive *primitive)
-      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {
+    matmul_param_ = new MatMulParameter();
+  }
   ~Convolution1x1FP16CPUKernel() override {
     if (fp16_input_ != nullptr) {
       free(fp16_input_);
@@ -40,14 +44,34 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
     if (fp16_out_ != nullptr) {
       free(fp16_out_);
     }
+    if (input_ptr_ != nullptr) {
+      free(input_ptr_);
+    }
+    if (weight_ptr_ != nullptr) {
+      free(weight_ptr_);
+    }
+    delete matmul_param_;
   }
 
   int Init() override;
   int ReSize() override;
   int Run() override;
   int RunImpl(int task_id);
+  int InitBuffer();
+  int InitConv1x1Param();
+  int InitMatmulParam();
+  int InitWeightBias();
+  void Pre1x1Trans(float16_t *src_input, float16_t *src_output);
 
  private:
+  bool pre_trans_input_ = false;
+  int thread_count_ = 0;
+  int thread_stride_ = 0;
+  float16_t *weight_ptr_ = nullptr;
+  float16_t *input_ptr_ = nullptr;
+  float16_t *pack_input_ = nullptr;
+  float16_t *output_ptr_ = nullptr;
+  MatMulParameter *matmul_param_ = nullptr;
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
index 46c443fb0a..fa129b3c81 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
@@ -62,7 +62,11 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
   memset(transformed_filter_addr_, 0, transformed_size);
-  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute filter failed.";
+    return ret;
+  }
   ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);
 
   // init bias
@@ -249,8 +253,11 @@ int Convolution3x3FP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
-
+  ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get execute tensor failed.";
+    return ret;
+  }
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
   int in_w = conv_param_->input_w_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index 235b7e48ff..7cb95f8165 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -18,6 +18,7 @@
 #include
 #include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
 #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
+#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
@@ -46,7 +47,11 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
   int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
 
   // init weight
-  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute filter failed.";
+    return ret;
+  }
   packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
   if (packed_weight_ == nullptr) {
     MS_LOG(ERROR) << "malloc packed_weight_ failed.";
@@ -218,7 +223,12 @@ int ConvolutionFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+
+  ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute tensor failed.";
+    return ret;
+  }
 
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
@@ -256,6 +266,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
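Reviewer note: the `@@ -256,6 +266,8 @@` hunk above is truncated; given its +2 line count and the newly included `convolution_1x1_fp16.h`, the added lines presumably wire the 1x1 kernel into the creator's dispatch. A hedged fragment-level sketch of that selection logic (names follow the surrounding file; this is a plausible reconstruction, not the verbatim patch):

```cpp
// Presumed kernel selection inside CpuConvFp16KernelCreator (assumption: mirrors
// the existing 3x3 branch; opParameter/inputs/outputs/ctx/primitive come from
// the enclosing creator's signature).
kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
  kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1) {
  // The two added lines likely route 1x1 convolutions to the new GEMM-based kernel.
  kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
  kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
```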
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc
@@ ... @@ int ConvolutionSWFP16CPUKernel::InitWeightBias() {
   int out_channel = conv_param_->output_channel_;
   int ic4 = UP_DIV(in_channel, C4NUM);
-  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute filter failed.";
+    return ret;
+  }
 
   for (int oc = 0; oc < out_channel; ++oc) {
     int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
@@ -228,7 +232,11 @@ int ConvolutionSWFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute tensor failed.";
+    return ret;
+  }
 
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
index 070f5bed6c..a3866c14e3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
@@ -115,7 +115,11 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
 
-  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute filter failed.";
+    return ret;
+  }
   WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);
 
   // init bias
@@ -377,7 +381,11 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
     return prepare_ret;
   }
 
-  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Get Execute tensor failed.";
+    return ret;
+  }
 
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
index 82428482f3..460bdaf7b1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
@@ -16,6 +16,11 @@
 
 #include "nnacl/fp16/matmul_fp16.h"
 
+void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
+                int depth, int row, int col, int stride, bool write_nhwc) {
+  MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
+}
+
 void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
   size_t row16 = row / C16NUM * C16NUM;
   size_t col8 = col / C8NUM * C8NUM;
@@ -134,7 +139,7 @@ void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row,
       :
       : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
       : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
-        "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+        "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
         "v31");
 #else
   for (int tr = 0; tr < C16NUM; tr++) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
index c9a4981caf..8156a11a1a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
@@ -29,10 +29,13 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
+                int depth, int row, int col, int stride, bool write_nhwc);
+
 void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
 #ifdef __aarch64__
 void MatmulFp16Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
-                      int col);
+                      int col, int stride, bool write_nhwc);
 #endif
 #ifdef __cplusplus
 }
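Reviewer note: `MatMulFp16` is a thin C wrapper over the aarch64 assembly routine `MatmulFp16Neon64`; judging from the call site in `RunImpl`, `a` is the 16-row-packed activations, `b` the 8-column-packed weights, `stride` is the leading dimension of `c`, and `write_nhwc` selects direct NHWC output. A hedged caller's-eye sketch of that contract (aarch64-only; it mirrors `RunImpl` rather than defining the ABI, and `RunOneTask` is a hypothetical helper):

```cpp
#ifdef __aarch64__
#include <arm_fp16.h>
#include "nnacl/fp16/matmul_fp16.h"

// Assumption: buffers were already packed by RowMajor2Col8MajorFp16, as in
// Convolution1x1FP16CPUKernel::RunImpl.
void RunOneTask(const float16_t *pack_input, const float16_t *weight_ptr, float16_t *output,
                const float16_t *bias, const MatMulParameter *p, int task_id, int thread_stride) {
  int cur_oc = p->col_ - task_id * thread_stride;
  if (cur_oc > thread_stride) cur_oc = thread_stride;
  if (cur_oc <= 0) return;  // trailing tasks may have no columns left
  // c is written at column offset task_id * thread_stride, with row stride col_.
  MatMulFp16(pack_input,
             weight_ptr + task_id * thread_stride * p->deep_,  // this task's weight slice
             output + task_id * thread_stride,                 // this task's output columns
             bias ? bias + task_id * thread_stride : nullptr,
             p->act_type_, p->deep_, p->row_, cur_oc, p->col_, /*write_nhwc=*/true);
}
#endif
```

Passing `col_` as the stride is what lets each task write its column slice of the shared NHWC output directly, with no scatter pass afterwards.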
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c
index 2e3532aaf4..dd84dcef22 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c
@@ -18,6 +18,27 @@
 #include
 #include
 
+void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) {
+  /* support nhwc */
+  for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
+    int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
+    if (src_h < 0 || src_h >= conv_param->input_h_) {
+      continue;
+    }
+    const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
+    float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
+    for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
+      int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
+      if (src_w < 0 || src_w >= conv_param->input_w_) {
+        continue;
+      }
+      memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_,
+             conv_param->input_channel_ * sizeof(float16_t));
+    }
+  }
+  return;
+}
+
 void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
                         int block_index) {
   // input format : nhwc
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h
index 188d9a0465..40d95be422 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h
@@ -26,6 +26,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param);
+
 void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
                         int block_index);
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
index 4bed0f0e07..be01e4beb2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
@@ -26,6 +26,7 @@ typedef struct MatMulParameter {
   int row_;
   int col_;
   int row_8_;
+  int row_16_;
   int col_8_;
   int deep_;
   bool has_bias_;
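Reviewer note: `Conv1x1InputPackFp16` exists because a padded or strided 1x1 convolution is not a pure GEMM; each output position (h, w) must first gather the input pixel at (h*stride - pad, w*stride - pad), skipping out-of-bounds taps, so the GEMM then sees a dense row-major matrix. The same index math in a self-contained float32 sketch (hypothetical helper, written in float32 so it runs without fp16 support):

```cpp
#include <cstring>

// Float32 restatement of Conv1x1InputPackFp16's gather (illustrative helper).
// src: one NHWC input plane; dst: output_h * output_w rows of channel values.
void Conv1x1GatherRows(const float *src, float *dst, int in_h, int in_w, int channel,
                       int out_h, int out_w, int stride_h, int stride_w, int pad_h, int pad_w) {
  for (int dh = 0; dh < out_h; ++dh) {
    int sh = dh * stride_h - pad_h;
    if (sh < 0 || sh >= in_h) continue;  // padded rows stay zero
    for (int dw = 0; dw < out_w; ++dw) {
      int sw = dw * stride_w - pad_w;
      if (sw < 0 || sw >= in_w) continue;  // padded columns stay zero
      std::memcpy(dst + (dh * out_w + dw) * channel, src + (sh * in_w + sw) * channel,
                  channel * sizeof(float));
    }
  }
}
```

Skipped positions are never written for any batch, which is why zeroing `input_ptr_` once in `InitConv1x1Param()` suffices: the padding stays zero across batches.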