@@ -28,8 +28,8 @@
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
+using mindspore::lite::RET_INFER_INVALID;
+using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Conv2D;
 namespace mindspore::kernel {
@@ -254,8 +254,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
   CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
   kernel::LiteKernel *kernel;
   if (kernel_h == 1 && kernel_w == 1) {
-    // kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
-    kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+    kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   } else if (use_winograd) {
@@ -171,6 +171,17 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
   return RET_OK;
 }
 
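+// Unpacks this task's column slice of the 8x8-tiled matmul output back into the
+// row-major output tensor (assumes thread_stride_ is a multiple of C8NUM).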
+int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) {
+  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
+  float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_;
+  float *dst = output_ptr_ + task_id * thread_stride_;
+  Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_);
+  return RET_OK;
+}
+
 int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
   auto error_code = conv1x1->DoConv1x1(task_id);
@@ -181,6 +192,12 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   return RET_OK;
 }
 
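+// Thread entry for the post stage: each task unpacks its own output slice in parallel.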
+int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
+  auto error_code = conv1x1->DoConv1x1Post(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "Convolution1x1Post task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int Convolution1x1CPUKernel::Run() {
   auto prepare_ret = Prepare();
   if (prepare_ret != RET_OK) {
@@ -200,7 +217,7 @@ int Convolution1x1CPUKernel::Run() {
       return RET_ERROR;
     }
-    Row8x8Major2RowMajor(pack_output_, output_ptr_, matmul_param_->row_, matmul_param_->col_);
+    LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_);
   }
   return RET_OK;
 }
@@ -46,6 +46,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   int DoConv1x1(int task_id);
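+  // Per-task unpack of the matmul output (launched via Convolution1x1Post).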
+  int DoConv1x1Post(int task_id);
 
  private:
   int InitConv1x1Param();
@@ -121,7 +121,7 @@ int FullconnectionCPUKernel::Run() {
   LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
-  Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_);
+  Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_, fc_param_->col_);
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -118,7 +118,7 @@ int MatmulCPUKernel::Run() {
       RowMajor2Row8Major(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_);
     }
     LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
-    Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_);
+    Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
   }
   return RET_OK;
 }
@@ -119,14 +119,104 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
   return;
 }
 
-void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col) {
-  int row8 = UP_ROUND(row, 8);
-  for (int c = 0; c < col; c++) {
-    int cd8 = c / 8;
-    int cm8 = c % 8;
-    for (int r = 0; r < row; r++) {
-      dst_ptr[r * col + c] = src_ptr[cd8 * row8 * 8 + r * 8 + cm8];
-    }
-  }
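+// Copies a row x col sub-block with independent row strides for source and destination.
+// Sizes and strides are in elements; data_length is the element size in bytes.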
+static inline void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride,
+                                    size_t dst_stride, size_t data_length) {
+  size_t copy_size = col * data_length;
+  size_t src_size = src_stride * data_length;
+  size_t dst_size = dst_stride * data_length;
+  const char *src_ptr = (const char *)src;
+  char *dst_ptr = (char *)dst;
+  for (size_t r = 0; r < row; r++) {
+    memcpy(dst_ptr, src_ptr, copy_size);
+    src_ptr += src_size;
+    dst_ptr += dst_size;
+  }
+}
+
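+// Unpacks an 8x8-tiled matrix (column tiles of 8; within a tile each row's 8 values are
+// contiguous) into plain row-major with a destination row stride of `stride` elements.
+// Full 8x8 tiles take the NEON path on ARM64, memcpy otherwise; partial tiles are
+// finished with MatrixUnPackUnit.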
+void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride) {
+  size_t row_up8 = UP_ROUND(row, C8NUM);
+  size_t row_8div = row / C8NUM * C8NUM;
+  size_t row_8res = row - row_8div;
+  size_t col_8div = col / C8NUM * C8NUM;
+  size_t col_8res = col - col_8div;
+  float *src_c = src_ptr;
+  float *dst_c = dst_ptr;
+  for (size_t ci = 0; ci < col_8div; ci += C8NUM) {
+#ifdef ENABLE_ARM64
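+    // NEON path: each loop iteration moves one full 8x8 tile (v0..v15 = 8 rows x 8 floats).
+    // Each output row is written as two 16-byte stores; the second advances by
+    // offset = stride * 4 - 16 bytes, landing on the next output row.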
+    size_t offset = stride * 4 - 16;
+    asm volatile(
+      "mov x0, #0 \n"
+      "mov x1, %[row_8div] \n"
+      "mov x10, %[src_c] \n"
+      "mov x11, %[dst_c] \n"
+      "Loop8x8: \n"
+      "cmp x0, x1 \n"
+      "beq End \n"
+      "ld1 {v0.4s}, [x10], #16\n"
+      "ld1 {v1.4s}, [x10], #16\n"
+      "ld1 {v2.4s}, [x10], #16\n"
+      "ld1 {v3.4s}, [x10], #16\n"
+      "ld1 {v4.4s}, [x10], #16\n"
+      "ld1 {v5.4s}, [x10], #16\n"
+      "ld1 {v6.4s}, [x10], #16\n"
+      "ld1 {v7.4s}, [x10], #16\n"
+      "ld1 {v8.4s}, [x10], #16\n"
+      "ld1 {v9.4s}, [x10], #16\n"
+      "ld1 {v10.4s}, [x10], #16\n"
+      "ld1 {v11.4s}, [x10], #16\n"
+      "ld1 {v12.4s}, [x10], #16\n"
+      "ld1 {v13.4s}, [x10], #16\n"
+      "ld1 {v14.4s}, [x10], #16\n"
+      "ld1 {v15.4s}, [x10], #16\n"
+      "add x0, x0, #8\n"
+      "st1 {v0.4s}, [x11], #16\n"
+      "st1 {v1.4s}, [x11], %[offset]\n"
+      "st1 {v2.4s}, [x11], #16\n"
+      "st1 {v3.4s}, [x11], %[offset]\n"
+      "st1 {v4.4s}, [x11], #16\n"
+      "st1 {v5.4s}, [x11], %[offset]\n"
+      "st1 {v6.4s}, [x11], #16\n"
+      "st1 {v7.4s}, [x11], %[offset]\n"
+      "st1 {v8.4s}, [x11], #16\n"
+      "st1 {v9.4s}, [x11], %[offset]\n"
+      "st1 {v10.4s}, [x11], #16\n"
+      "st1 {v11.4s}, [x11], %[offset]\n"
+      "st1 {v12.4s}, [x11], #16\n"
+      "st1 {v13.4s}, [x11], %[offset]\n"
+      "st1 {v14.4s}, [x11], #16\n"
+      "st1 {v15.4s}, [x11], %[offset]\n"
+      "b Loop8x8\n"
+      "End:\n"
+      :
+      : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ offset ] "r"(offset), [ row_8div ] "r"(row_8div)
+      : "cc", "memory", "x0", "x1", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+        "v10", "v11", "v12", "v13", "v14", "v15");
+#else
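+    // Portable path: copy each full 8x8 tile row by row with memcpy.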
+    for (size_t ri = 0; ri < row_8div; ri += C8NUM) {
+      float *src_r = src_c + ri * C8NUM;
+      float *dst_r = dst_c + ri * stride;
+      MatrixUnPackUnit(src_r, dst_r, C8NUM, C8NUM, C8NUM, stride, sizeof(float));
+    }
+#endif
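+    // Leftover rows (row % 8) of this column tile.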
+    if (row != row_8div) {
+      float *src_r = src_c + row_8div * C8NUM;
+      float *dst_r = dst_c + row_8div * stride;
+      MatrixUnPackUnit(src_r, dst_r, row_8res, C8NUM, C8NUM, stride, sizeof(float));
+    }
+    src_c += row_up8 * C8NUM;
+    dst_c += C8NUM;
+  }
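+  // Leftover columns (col % 8): one narrow partial tile spanning every row.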
+  if (col != col_8div) {
+    MatrixUnPackUnit(src_c, dst_c, row, col_8res, C8NUM, stride, sizeof(float));
+  }
   return;
 }
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_MATMUL_H_
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_MATMUL_H_
 
 #include <string.h>
+#include <float.h>
 #include "src/runtime/kernel/arm/nnacl/errorcode.h"
 #include "src/runtime/kernel/arm/nnacl/op_base.h"
@@ -25,7 +26,7 @@
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);
+void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
 void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
                int row_8_, int col_8_);
 #ifdef __cplusplus
@@ -88,7 +88,7 @@ TEST_F(TestMatMulFp32, Row8x82RowTest1) {
                 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02,
                 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52};
   float out[90] = {0};
-  Row8x8Major2RowMajor(in, out, 18, 5);
+  Row8x8Major2RowMajor(in, out, 18, 5, 5);
   CompareOutputData(out, co, 90, 0.0001);
 }
@@ -100,7 +100,7 @@ TEST_F(TestMatMulFp32, Row8x82RowTest2) {
   float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90,
                 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39};
   float out[30] = {0};
-  Row8x8Major2RowMajor(in, out, 6, 5);
+  Row8x8Major2RowMajor(in, out, 6, 5, 5);
   CompareOutputData(out, co, 30, 0.0001);
 }
@@ -161,10 +161,24 @@ TEST_F(TestMatMulFp32, Row8x82RowTest3) {
     0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52,
     0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53};
   float out[418] = {0};
-  Row8x8Major2RowMajor(in, out, 22, 19);
+  Row8x8Major2RowMajor(in, out, 22, 19, 19);
   CompareOutputData(out, co, 418, 0.0001);
 }
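+
+// Exactly one full 8x8 tile: the tiled and row-major layouts coincide, so output equals input.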
+TEST_F(TestMatMulFp32, Row8x82RowTest4) {
+  float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
+                0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
+                0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
+                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
+  float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
+                0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
+                0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
+                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
+  float out[64] = {0};
+  Row8x8Major2RowMajor(in, out, 8, 8, 8);
+  CompareOutputData(out, co, 64, 0.0001);
+}
 
 int MMTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
                float *a_ptr, float *b_ptr, std::vector<int> a_shape, std::vector<int> b_shape,
                std::vector<int> c_shape) {
@@ -160,7 +160,7 @@ TEST_F(TestDeconvInt8, MatMulTest1) {
   RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12);
   RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12);
   MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b);
-  Row8x8Major2RowMajor(reinterpret_cast<float *>(c_row8x8_major), reinterpret_cast<float *>(out_row_major), 10, 18);
+  Row8x8Major2RowMajor(reinterpret_cast<float *>(c_row8x8_major), reinterpret_cast<float *>(out_row_major), 10, 18, 18);
   CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
 }