diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
index e0e00f5bcf..764a22eb51 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
@@ -92,9 +92,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
   memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t));
-  RowMajor2Col8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->col_,
-                         matmul_param_->deep_);
-
+  ColMajor2Row8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->deep_,
+                         matmul_param_->col_);
   return RET_OK;
 }
 
@@ -159,7 +158,7 @@ void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *s
     input_ptr_ = src_input;
   }
 
-  RowMajor2Col8MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+  RowMajor2Col16MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
   return;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
index 919fe41f36..a37f6573a1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
@@ -35,15 +35,12 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
     matmul_param_ = new MatMulParameter();
   }
   ~Convolution1x1FP16CPUKernel() override {
-    if (fp16_weight_ != nullptr) {
-      free(fp16_weight_);
-    }
-    if (input_ptr_ != nullptr) {
-      free(input_ptr_);
-    }
     if (weight_ptr_ != nullptr) {
       free(weight_ptr_);
     }
+    if (pack_input_ != nullptr) {
+      free(pack_input_);
+    }
     delete matmul_param_;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
index b111624431..434b0161b5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
@@ -255,7 +255,7 @@ int Convolution3x3FP16CPUKernel::Run() {
   bool relu6 = conv_param_->is_relu6_;
   for (int batch = 0; batch < conv_param_->output_batch_; batch++) {
     int tmp_out_batch_offset =
-      batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
+      batch * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
     int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
     const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
     float16_t *batch_out = execute_output_ + ro_batch_size;
@@ -265,7 +265,7 @@ int Convolution3x3FP16CPUKernel::Run() {
         int oc8_block = c / C8NUM;
         int oc8_res = c % C8NUM;
         int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
-                         C8NUM * (h * out_w_block * conv_param_->output_unit_ + w) + oc8_res;
+                         C8NUM * (h * out_w_block * C4NUM + w) + oc8_res;
         int dst_offset = (h * conv_param_->output_w_ + w) * conv_param_->output_channel_ + c;
         (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
         if (relu) {
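Note on the activation repack in Pre1x1Trans above: RowMajor2Col16MajorFp16 blocks rows in groups of 16 so that every depth step feeds 16 contiguous values to the 16x8 fp16 micro-kernel (this is the ai index used by MatMul16x8 further below). A minimal scalar sketch of that layout, assuming C16NUM == 16 and substituting float for float16_t so it builds off-device; RowMajor2Col16MajorRef is a hypothetical stand-in, not the nnacl routine:

    #include <stdio.h>
    #include <string.h>

    #define C16NUM 16
    #define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

    /* Pack a row-major (row x deep) matrix into col16-major tiles: tile t holds
     * rows [16t, 16t + 16); within a tile, depth step d stores its 16 row values
     * contiguously, i.e. dst[(r/16)*deep*16 + d*16 + r%16] = src[r*deep + d]. */
    static void RowMajor2Col16MajorRef(const float *src, float *dst, int row, int deep) {
      memset(dst, 0, UP_ROUND(row, C16NUM) * deep * sizeof(float));
      for (int r = 0; r < row; r++) {
        for (int d = 0; d < deep; d++) {
          dst[(r / C16NUM) * deep * C16NUM + d * C16NUM + (r % C16NUM)] = src[r * deep + d];
        }
      }
    }

    int main(void) {
      enum { ROW = 18, DEEP = 3 };
      float src[ROW * DEEP], dst[UP_ROUND(ROW, C16NUM) * DEEP];
      for (int i = 0; i < ROW * DEEP; i++) src[i] = (float)i;
      RowMajor2Col16MajorRef(src, dst, ROW, DEEP);
      /* row 17 / depth 2 lands in tile 1 at offset d * 16 + (17 % 16): prints 53 */
      printf("%.0f\n", dst[1 * DEEP * C16NUM + 2 * C16NUM + 1]);
      return 0;
    }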
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
index 5d44fbb003..55bfa4d583 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
@@ -47,7 +47,7 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
   if (weight_data_type == kNumberTypeFloat32) {
     float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
     size_t fp16_weight_size = conv_param_->input_channel_ * conv_param_->output_channel_ * conv_param_->kernel_h_ *
-                              conv_param_->input_w_ * sizeof(float16_t);
+                              conv_param_->kernel_w_ * sizeof(float16_t);
     fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
     if (fp16_weight_ == nullptr) {
       MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index a3ea280767..5e8f121f30 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -219,8 +219,8 @@ int ConvolutionFP16CPUKernel::Run() {
     return RET_ERROR;
   }
 
-  ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
   ConvolutionBaseFP16CPUKernel::IfCastOutput();
+  ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
index 6732293016..7f29a101f0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
@@ -170,7 +170,7 @@ int DeConvolutionFp16CPUKernel::Run() {
   ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
 
   for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
-    RowMajor2Col8MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
+    RowMajor2Col16MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
 
     int error_code = LiteBackendParallelLaunch(DeConvFp16Run, this, thread_count_);
     if (error_code != RET_OK) {
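Note on the weight repack in InitWeightBias: the 1x1 conv weight arrives as (out_channel x in_channel) row-major, which is column-major when read as (deep x col), so ColMajor2Row8MajorFp16 (added in matmul_fp16.c below) groups every 8 output channels contiguously per depth step, which is exactly the bi index the matmul kernels consume. A small self-checking sketch of that contract, assuming C8NUM == 8 and using float in place of float16_t; ColMajor2Row8MajorRef is a hypothetical stand-in, not the library code:

    #include <assert.h>
    #include <string.h>

    #define C8NUM 8
    #define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

    /* src is column-major (row x col), i.e. src[c * row + r]; dst packs 8
     * consecutive columns per depth step: dst[(c/8)*8*row + r*8 + c%8]. */
    static void ColMajor2Row8MajorRef(const float *src, float *dst, int row, int col) {
      for (int r = 0; r < row; r++) {
        for (int c = 0; c < col; c++) {
          dst[(c / C8NUM) * C8NUM * row + r * C8NUM + (c % C8NUM)] = src[c * row + r];
        }
      }
    }

    int main(void) {
      enum { DEEP = 5, COL = 11 }; /* deep = input channels, col = output channels */
      float w[COL * DEEP]; /* (col x deep) row-major == (deep x col) col-major */
      float packed[UP_ROUND(COL, C8NUM) * DEEP];
      memset(packed, 0, sizeof(packed));
      for (int i = 0; i < COL * DEEP; i++) w[i] = (float)i;
      ColMajor2Row8MajorRef(w, packed, DEEP, COL);
      /* MatMul16x8 reads b[(c/8)*deep*8 + d*8 + c%8]; it must see w[c*deep + d]. */
      for (int c = 0; c < COL; c++) {
        for (int d = 0; d < DEEP; d++) {
          assert(packed[(c / C8NUM) * DEEP * C8NUM + d * C8NUM + (c % C8NUM)] == w[c * DEEP + d]);
        }
      }
      return 0;
    }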
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
index bbb4360beb..b59cba408d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
@@ -15,27 +15,57 @@
  */
 #include "nnacl/fp16/matmul_fp16.h"
 
+void ColMajor2Row8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      int cd8 = c / 8;
+      int cm8 = c % 8;
+      dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src_ptr[c * row + r];
+    }
+  }
+}
+
 void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                 int deep, int row, int col, int stride, bool write_nhwc) {
   int row_16 = UP_ROUND(row, C16NUM);
   int col_8 = UP_ROUND(col, C8NUM);
-  /* col16-major * row8-major => row16x8-major */
-  if (write_nhwc) return;
-  for (int r = 0; r < row_16; r++) {
-    for (int c = 0; c < col_8; c++) {
-      int r16div = r / C16NUM, r16mod = r % C16NUM;
-      int c8div = c / C8NUM, c8mod = c % C8NUM;
-      size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
-      float16_t value = 0;
-      for (int d = 0; d < deep; d++) {
-        size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
-        size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
-        value = value + a[ai] * b[bi];
+  if (write_nhwc) {
+    /* col16-major * row8-major => col-major */
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r16div = r / C16NUM, r16mod = r % C16NUM;
+        int c8div = c / C8NUM, c8mod = c % C8NUM;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
+          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  } else {
+    /* col16-major * row8-major => row16x8-major */
+    for (int r = 0; r < row_16; r++) {
+      for (int c = 0; c < col_8; c++) {
+        int r16div = r / C16NUM, r16mod = r % C16NUM;
+        int c8div = c / C8NUM, c8mod = c % C8NUM;
+        size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
+        float16_t value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
+          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
       }
-      if (bias != NULL) value += bias[col];
-      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
-      if (act_type != ActType_No) value = MSMAX(0.0f, value);
-      dst[ci] = value;
     }
   }
   return;
@@ -43,12 +73,12 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
 
 void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                 int depth, int row, int col, int stride, bool write_nhwc) {
-  // MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
-  MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
+  MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
+  // MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
   return;
 }
 
-void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
   size_t row16 = row / C16NUM * C16NUM;
   size_t col8 = col / C8NUM * C8NUM;
   float16_t *src_r = src_ptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
index 237d6cb141..e3d7e8b90e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
@@ -32,7 +32,9 @@ extern "C" {
 void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                 int depth, int row, int col, int stride, bool write_nhwc);
 
-void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
+void ColMajor2Row8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
+
+void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
 
 void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                       size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
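End-to-end, the contract these changes establish is: pack A row-major to col16-major (RowMajor2Col16MajorFp16), pack B from (col x deep) column-major to row8-major (ColMajor2Row8MajorFp16), then MatMulFp16 with write_nhwc writes a plain row-major result at ci = r * stride + c. A stand-alone sanity sketch checking that contract against a naive matmul, in plain float with no bias or activation; this is a scalar stand-in, not the MatmulFp16Neon64 kernel, and stride is assumed equal to col:

    #include <assert.h>
    #include <math.h>
    #include <stdlib.h>

    #define C8NUM 8
    #define C16NUM 16
    #define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

    int main(void) {
      enum { ROW = 20, DEEP = 7, COL = 10 };
      float a[ROW * DEEP], w[COL * DEEP];
      float pa[UP_ROUND(ROW, C16NUM) * DEEP] = {0}; /* col16-major A */
      float pw[UP_ROUND(COL, C8NUM) * DEEP] = {0};  /* row8-major B */
      float out[ROW * COL], ref[ROW * COL];
      srand(7);
      for (int i = 0; i < ROW * DEEP; i++) a[i] = (float)(rand() % 5);
      for (int i = 0; i < COL * DEEP; i++) w[i] = (float)(rand() % 5);
      /* pack as Pre1x1Trans / InitWeightBias do */
      for (int r = 0; r < ROW; r++)
        for (int d = 0; d < DEEP; d++)
          pa[(r / C16NUM) * DEEP * C16NUM + d * C16NUM + r % C16NUM] = a[r * DEEP + d];
      for (int c = 0; c < COL; c++)
        for (int d = 0; d < DEEP; d++)
          pw[(c / C8NUM) * DEEP * C8NUM + d * C8NUM + c % C8NUM] = w[c * DEEP + d];
      /* scalar version of the write_nhwc loop in MatMul16x8 (stride == COL) */
      for (int r = 0; r < ROW; r++) {
        for (int c = 0; c < COL; c++) {
          float v = 0;
          for (int d = 0; d < DEEP; d++) {
            v += pa[(r / C16NUM) * DEEP * C16NUM + d * C16NUM + r % C16NUM] *
                 pw[(c / C8NUM) * DEEP * C8NUM + d * C8NUM + c % C8NUM];
          }
          out[r * COL + c] = v;
        }
      }
      /* naive reference: out[r][c] = sum_d a[r][d] * w[c][d] */
      for (int r = 0; r < ROW; r++) {
        for (int c = 0; c < COL; c++) {
          float v = 0;
          for (int d = 0; d < DEEP; d++) v += a[r * DEEP + d] * w[c * DEEP + d];
          ref[r * COL + c] = v;
        }
      }
      for (int i = 0; i < ROW * COL; i++) assert(fabsf(out[i] - ref[i]) < 1e-4f);
      return 0;
    }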