| @@ -47,7 +47,7 @@ | |||||
| ///////////////////////////////////////////////////////////////////////////////// | ///////////////////////////////////////////////////////////////////////////////// | ||||
| // | // | ||||
| // void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, | // void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, | ||||
| // int row, int col, int stride, bool write_nhwc) | |||||
| // int row, int col, size_t stride, size_t writeNhwc, size_t writeC4) | |||||
| // x0: a | // x0: a | ||||
| // x1: b | // x1: b | ||||
| // x2: c | // x2: c | ||||
| @@ -57,7 +57,7 @@ | |||||
| // w6: row | // w6: row | ||||
| // w7: col | // w7: col | ||||
| // x17: stride | // x17: stride | ||||
| // w13: writeC8 | |||||
| // w13: c8_nhwc_c4 | |||||
| MatmulFloatNeon64Opt: | MatmulFloatNeon64Opt: | ||||
| sub sp, sp, #128 | sub sp, sp, #128 | ||||
| @@ -209,8 +209,8 @@ Activation: | |||||
| b Write | b Write | ||||
| Relu6: | Relu6: | ||||
| mov w8, #6 | |||||
| dup v2.4s, w8 | |||||
| mov w13, #6 | |||||
| dup v2.4s, w13 | |||||
| scvtf v2.4s, v2.4s | scvtf v2.4s, v2.4s | ||||
| fmin v8.4s, v8.4s, v2.4s | fmin v8.4s, v8.4s, v2.4s | ||||
| fmin v9.4s, v9.4s, v2.4s | fmin v9.4s, v9.4s, v2.4s | ||||
| @@ -265,8 +265,10 @@ Relu: | |||||
| fmax v31.4s, v31.4s, v3.4s | fmax v31.4s, v31.4s, v3.4s | ||||
| Write: | Write: | ||||
| ldrb w13, [sp, #8] | |||||
| cbz w13, WriteC8 | |||||
| ldr w8, [sp, #8] | |||||
| cbz w8, WriteC8 | |||||
| ldr w8, [sp, #16] | |||||
| cbnz w8, WriteC4 | |||||
| cmp w7, #1 | cmp w7, #1 | ||||
| beq Write1 | beq Write1 | ||||
| cmp w7, #2 | cmp w7, #2 | ||||
| @@ -726,6 +728,33 @@ WriteC8: | |||||
| st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 | st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 | ||||
| st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 | st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 | ||||
| b WriteEnd | b WriteEnd | ||||
| WriteC4: | |||||
| st1 {v8.8h}, [x2], #16 | |||||
| st1 {v10.8h}, [x2], #16 | |||||
| st1 {v12.8h}, [x2], #16 | |||||
| st1 {v14.8h}, [x2], #16 | |||||
| st1 {v16.8h}, [x2], #16 | |||||
| st1 {v18.8h}, [x2], #16 | |||||
| st1 {v20.8h}, [x2], #16 | |||||
| st1 {v22.8h}, [x2], #16 | |||||
| st1 {v24.8h}, [x2], #16 | |||||
| st1 {v26.8h}, [x2], #16 | |||||
| st1 {v28.8h}, [x2], #16 | |||||
| st1 {v30.8h}, [x2], #16 | |||||
| add x18, x2, x17 | |||||
| st1 {v9.8h}, [x18], #16 | |||||
| st1 {v11.8h}, [x18], #16 | |||||
| st1 {v13.8h}, [x18], #16 | |||||
| st1 {v15.8h}, [x18], #16 | |||||
| st1 {v17.8h}, [x18], #16 | |||||
| st1 {v19.8h}, [x18], #16 | |||||
| st1 {v21.8h}, [x18], #16 | |||||
| st1 {v23.8h}, [x18], #16 | |||||
| st1 {v25.8h}, [x18], #16 | |||||
| st1 {v27.8h}, [x18], #16 | |||||
| st1 {v29.8h}, [x18], #16 | |||||
| st1 {v31.8h}, [x18], #16 | |||||
| b WriteEnd | |||||
| Write8: | Write8: | ||||
| st1 {v8.4s, v9.4s}, [x18], x17 | st1 {v8.4s, v9.4s}, [x18], x17 | ||||
| cmp w10, #1 | cmp w10, #1 | ||||
| @@ -770,9 +799,14 @@ End2: | |||||
| subs w7, w7, #8 // rhs col - 8 | subs w7, w7, #8 // rhs col - 8 | ||||
| add x1, x1, x15 // rhs ptr + stride | add x1, x1, x15 // rhs ptr + stride | ||||
| add x3, x3, #32 // bias ptr + stride | add x3, x3, #32 // bias ptr + stride | ||||
| ldrb w13, [sp, #8] | |||||
| cbz w13, NoDstStep | |||||
| ldr w8, [sp, #8] | |||||
| cbz w8, NoDstStep | |||||
| ldr w8, [sp, #16] | |||||
| cbnz w8, C4DstStep | |||||
| add x2, x2, #32 // dst ptr + stride | add x2, x2, #32 // dst ptr + stride | ||||
| b NoDstStep | |||||
| C4DstStep: | |||||
| add x2, x2, x17 | |||||
| NoDstStep: | NoDstStep: | ||||
| bgt L1 | bgt L1 | ||||
| @@ -370,8 +370,8 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac | |||||
| } | } | ||||
| void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | ||||
| int col, int stride, bool write_nhwc) { | |||||
| if (write_nhwc) { | |||||
| int col, size_t stride, size_t writeNhwc, size_t writeC4) { | |||||
| if (writeNhwc != 0) { | |||||
| /* col8-major * row8-major => col-major */ | /* col8-major * row8-major => col-major */ | ||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| for (int c = 0; c < col; c++) { | for (int c = 0; c < col; c++) { | ||||
| @@ -404,10 +404,10 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType | |||||
| } | } | ||||
| void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, | void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, | ||||
| int col, int stride, bool write_nhwc) { | |||||
| int col, size_t stride, size_t writeNhwc, size_t writeC4) { | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc); | |||||
| MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, writeNhwc, writeC4); | |||||
| #else | #else | ||||
| MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc); | |||||
| MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, writeNhwc, writeC4); | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ extern "C" { | |||||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, | void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, | ||||
| int stride, bool write_nhwc); | int stride, bool write_nhwc); | ||||
| void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, | void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, | ||||
| int col, int stride, bool write_nhwc); | |||||
| int col, size_t stride, size_t writeNhwc, size_t writeC4); | |||||
| void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| @@ -38,7 +38,7 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col | |||||
| void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, size_t stride, bool write_nhwc); | int col, size_t stride, bool write_nhwc); | ||||
| void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, size_t stride, bool write_nhwc); | |||||
| int col, size_t stride, size_t writeNhwc, size_t writeC4); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -145,7 +145,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | |||||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | ||||
| output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, | output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, | ||||
| matmul_param_->row_, cur_oc, matmul_param_->col_, true); | |||||
| matmul_param_->row_, cur_oc, matmul_param_->col_, 1, 0); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||