@@ -0,0 +1,102 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Corner
#ifndef __APPLE__
    .type ConvDw3x3Corner, %function
#endif

// void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
//                      int in_kw_step, int channel, size_t relu, size_t relu6)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6
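// Computes one corner border pixel over a 2x2 input window; the portable reference
// is the C fallback ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, ...) built
// when ENABLE_ARM64 is not defined. Channels are processed four floats at a time.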
ConvDw3x3Corner:
    // Registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
    // x19 ~ x29 should also be preserved; this kernel does not use them, and our coding
    // style does not permit such a number of parameters anyway.
    ldr x8, [sp]
    mov x9, #4
    mul x13, x6, x9  // x6 * 4
    mul x4, x4, x9
    mul x5, x5, x9
    mov x9, #3
    mul x14, x13, x9  // x6 * 3 * 4

    movi v26.4s, #6
    scvtf v26.4s, v26.4s
    dup v27.4s, wzr
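    // v26 = {6.0f x4} is the relu6 upper bound; v27 = 0 is the relu/relu6 lower bound.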
    ld1 {v23.4s}, [x3], #16
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    add x11, x1, x4
    ld1 {v4.4s}, [x10], x13  // weight
    add x12, x2, x14
    ld1 {v1.4s}, [x9], x5
    ld1 {v5.4s}, [x10], x13
    ld1 {v2.4s}, [x11], x5
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    ld1 {v7.4s}, [x12], x13

    cmp x6, #4
    ble LoopC4Post

LoopC4:
    add x1, x1, #16
    add x2, x2, #16
    fmla v23.4s, v0.4s, v4.4s
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    ld1 {v4.4s}, [x10], x13
    add x11, x1, x4
    fmla v23.4s, v1.4s, v5.4s
    add x12, x2, x14
    ld1 {v1.4s}, [x9], x5
    fmla v23.4s, v2.4s, v6.4s
    ld1 {v5.4s}, [x10], x13
    ld1 {v2.4s}, [x11], x5
    fmla v23.4s, v3.4s, v7.4s
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    ld1 {v7.4s}, [x12], x13

    cbnz x8, C4_RELU6
    cbnz x7, C4_RELU
    b C4_WRITE
C4_RELU6:
    fmin v23.4s, v23.4s, v26.4s
C4_RELU:
    fmax v23.4s, v23.4s, v27.4s
C4_WRITE:
    st1 {v23.4s}, [x0], #16
    ld1 {v23.4s}, [x3], #16
    sub x6, x6, #4
    cmp x6, #4
    bgt LoopC4

LoopC4Post:
    fmla v23.4s, v0.4s, v4.4s
    fmla v23.4s, v1.4s, v5.4s
    fmla v23.4s, v2.4s, v6.4s
    fmla v23.4s, v3.4s, v7.4s
    cbnz x8, RELU6
    cbnz x7, RELU
    b WRITE
RELU6:
    fmin v23.4s, v23.4s, v26.4s
RELU:
    fmax v23.4s, v23.4s, v27.4s
WRITE:
    st1 {v23.4s}, [x0], #16
    ret
#endif

@@ -0,0 +1,118 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Horizontal
#ifndef __APPLE__
    .type ConvDw3x3Horizontal, %function
#endif

// void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
//                          int in_kw_step, int channel, size_t relu, size_t relu6)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6
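// Computes one left/right border pixel over a 3x2 input window (three kernel rows,
// two columns); the portable reference is the C fallback
// ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, ...).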
ConvDw3x3Horizontal:
    // Registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
    // x19 ~ x29 should also be preserved; this kernel does not use them, and our coding
    // style does not permit such a number of parameters anyway.
    ldr x8, [sp]
    mov x9, #4
    mul x13, x6, x9  // x6 * 4
    mul x4, x4, x9
    mul x5, x5, x9
    mov x9, #3
    mul x14, x13, x9  // x6 * 3 * 4

    movi v26.4s, #6
    scvtf v26.4s, v26.4s
    dup v27.4s, wzr

    ld1 {v23.4s}, [x3], #16
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    add x11, x1, x4
    ld1 {v4.4s}, [x10], x13
    add x12, x2, x14
    ld1 {v1.4s}, [x9], x5
    ld1 {v5.4s}, [x10], x13
    add x15, x11, x4
    ld1 {v2.4s}, [x11], x5
    add x16, x12, x14
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    ld1 {v7.4s}, [x12], x13
    ld1 {v16.4s}, [x15], x5
    ld1 {v18.4s}, [x16], x13
    ld1 {v17.4s}, [x15], x5
    ld1 {v19.4s}, [x16], x13

    cmp x6, #4
    ble LoopC4Post

LoopC4:
    add x1, x1, #16
    add x2, x2, #16
    fmla v23.4s, v0.4s, v4.4s
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    ld1 {v4.4s}, [x10], x13
    add x11, x1, x4
    fmla v23.4s, v1.4s, v5.4s
    add x12, x2, x14
    ld1 {v1.4s}, [x9], x5
    fmla v23.4s, v2.4s, v6.4s
    add x15, x11, x4
    ld1 {v5.4s}, [x10], x13
    ld1 {v2.4s}, [x11], x5
    fmla v23.4s, v3.4s, v7.4s
    add x16, x12, x14
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    fmla v23.4s, v16.4s, v18.4s
    ld1 {v7.4s}, [x12], x13
    ld1 {v16.4s}, [x15], x5
    fmla v23.4s, v17.4s, v19.4s
    ld1 {v18.4s}, [x16], x13
    ld1 {v17.4s}, [x15], x5
    ld1 {v19.4s}, [x16], x13

    cbnz x8, C4_RELU6
    cbnz x7, C4_RELU
    b C4_WRITE
C4_RELU6:
    fmin v23.4s, v23.4s, v26.4s
C4_RELU:
    fmax v23.4s, v23.4s, v27.4s
C4_WRITE:
    st1 {v23.4s}, [x0], #16
    ld1 {v23.4s}, [x3], #16
    sub x6, x6, #4
    cmp x6, #4
    bgt LoopC4

LoopC4Post:
    fmla v23.4s, v0.4s, v4.4s
    fmla v23.4s, v1.4s, v5.4s
    fmla v23.4s, v2.4s, v6.4s
    fmla v23.4s, v3.4s, v7.4s
    fmla v23.4s, v16.4s, v18.4s
    fmla v23.4s, v17.4s, v19.4s
    cbnz x8, RELU6
    cbnz x7, RELU
    b WRITE
RELU6:
    fmin v23.4s, v23.4s, v26.4s
RELU:
    fmax v23.4s, v23.4s, v27.4s
WRITE:
    st1 {v23.4s}, [x0], #16
    ret
#endif

@@ -0,0 +1,199 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Stride1
#ifndef __APPLE__
    .type ConvDw3x3Stride1, %function
#endif

// void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size,
//                       int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6)
//
// x0: output
// x1: input
// x2: weight
// x3: bias
// w4: col_size
// w5: row_size
// w6: channel
// w7: output_h
// w8: output_w
// w9: relu
// w10: relu6
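// Data layout assumed below: v0 ~ v8 hold the nine 3x3 weight taps for one
// 4-channel slice, x11/x12/x13 walk the three input rows, and each WIDTH2_LOOP
// iteration accumulates two adjacent output pixels in v21/v22. The bias is
// reloaded from x3 for every pixel because each call covers a single 4-channel
// slice (the caller advances output/buffer/weight/bias by four floats per slice).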
ConvDw3x3Stride1:
    sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

    ldr w8, [sp]
    ldr w9, [sp, #8]
    ldr w10, [sp, #16]
    mov w11, #4
    mul w15, w4, w11   // col_size * 4
    mul w16, w6, w11   // channel * 4
    mul w17, w5, w11   // row_size * 4
    mov w11, #2
    mul w14, w11, w15  // col_size * 2 * 4

    movi v23.4s, #6
    scvtf v23.4s, v23.4s
    dup v24.4s, wzr

    // Load weights
    ld1 {v0.4s}, [x2], x16
    ld1 {v1.4s}, [x2], x16
    ld1 {v2.4s}, [x2], x16
    ld1 {v3.4s}, [x2], x16
    ld1 {v4.4s}, [x2], x16
    ld1 {v5.4s}, [x2], x16
    ld1 {v6.4s}, [x2], x16
    ld1 {v7.4s}, [x2], x16
    ld1 {v8.4s}, [x2], x16

    mov x11, x1
    add x12, x11, x17
    add x13, x12, x17
    ld1 {v9.4s}, [x11], x15
    ld1 {v10.4s}, [x11], x15
    ld1 {v11.4s}, [x11], x15
    ld1 {v13.4s}, [x12], x15
    ld1 {v14.4s}, [x12], x15
    ld1 {v15.4s}, [x12], x15
    ld1 {v17.4s}, [x13], x15
    ld1 {v18.4s}, [x13], x15
    ld1 {v19.4s}, [x13], x15
    ld1 {v21.4s}, [x3]
    ld1 {v22.4s}, [x3]

    cmp w8, #2
    beq WIDTH2_LEFT
    cmp w8, #1
    beq WIDTH1_LEFT

WIDTH2_LOOP:
    fmla v21.4s, v0.4s, v9.4s
    ld1 {v12.4s}, [x11]
    ld1 {v16.4s}, [x12]
    fmla v22.4s, v0.4s, v10.4s
    ld1 {v20.4s}, [x13]
    add x1, x1, x14
    fmla v21.4s, v1.4s, v10.4s
    mov x11, x1
    add x12, x11, x17
    add x13, x12, x17
    ld1 {v9.4s}, [x11], x15
    fmla v22.4s, v1.4s, v11.4s
    ld1 {v10.4s}, [x11], x15
    fmla v21.4s, v2.4s, v11.4s
    fmla v22.4s, v2.4s, v12.4s
    fmla v21.4s, v3.4s, v13.4s
    ld1 {v11.4s}, [x11], x15
    fmla v22.4s, v3.4s, v14.4s
    fmla v21.4s, v4.4s, v14.4s
    ld1 {v13.4s}, [x12], x15
    fmla v22.4s, v4.4s, v15.4s
    fmla v21.4s, v5.4s, v15.4s
    ld1 {v14.4s}, [x12], x15
    fmla v22.4s, v5.4s, v16.4s
    fmla v21.4s, v6.4s, v17.4s
    ld1 {v15.4s}, [x12], x15
    fmla v22.4s, v6.4s, v18.4s
    fmla v21.4s, v7.4s, v18.4s
    ld1 {v17.4s}, [x13], x15
    fmla v22.4s, v7.4s, v19.4s
    fmla v21.4s, v8.4s, v19.4s
    ld1 {v18.4s}, [x13], x15
    fmla v22.4s, v8.4s, v20.4s
    ld1 {v19.4s}, [x13], x15

    cbnz x10, WIDTH2_RELU6
    cbnz x9, WIDTH2_RELU
    b WIDTH2_WRITE
WIDTH2_RELU6:
    fmin v21.4s, v21.4s, v23.4s
    fmin v22.4s, v22.4s, v23.4s
WIDTH2_RELU:
    fmax v21.4s, v21.4s, v24.4s
    fmax v22.4s, v22.4s, v24.4s
WIDTH2_WRITE:
    st1 {v21.4s}, [x0], x16
    ld1 {v21.4s}, [x3]
    st1 {v22.4s}, [x0], x16
    ld1 {v22.4s}, [x3]

    sub w8, w8, #2
    cmp w8, #2
    bgt WIDTH2_LOOP
    cmp w8, #2
    blt WIDTH1_LEFT

WIDTH2_LEFT:
    fmla v21.4s, v0.4s, v9.4s
    ld1 {v12.4s}, [x11]
    fmla v22.4s, v0.4s, v10.4s
    fmla v21.4s, v1.4s, v10.4s
    ld1 {v16.4s}, [x12]
    fmla v22.4s, v1.4s, v11.4s
    fmla v21.4s, v2.4s, v11.4s
    ld1 {v20.4s}, [x13]
    fmla v22.4s, v2.4s, v12.4s
    fmla v21.4s, v3.4s, v13.4s
    fmla v22.4s, v3.4s, v14.4s
    fmla v21.4s, v4.4s, v14.4s
    fmla v22.4s, v4.4s, v15.4s
    fmla v21.4s, v5.4s, v15.4s
    fmla v22.4s, v5.4s, v16.4s
    fmla v21.4s, v6.4s, v17.4s
    fmla v22.4s, v6.4s, v18.4s
    fmla v21.4s, v7.4s, v18.4s
    fmla v22.4s, v7.4s, v19.4s
    fmla v21.4s, v8.4s, v19.4s
    fmla v22.4s, v8.4s, v20.4s

    cbnz x10, WIDTH2_LEFT_RELU6
    cbnz x9, WIDTH2_LEFT_RELU
    b WIDTH2_LEFT_WRITE
WIDTH2_LEFT_RELU6:
    fmin v21.4s, v21.4s, v23.4s
    fmin v22.4s, v22.4s, v23.4s
WIDTH2_LEFT_RELU:
    fmax v21.4s, v21.4s, v24.4s
    fmax v22.4s, v22.4s, v24.4s
WIDTH2_LEFT_WRITE:
    st1 {v21.4s}, [x0], x16
    st1 {v22.4s}, [x0], x16
    b End

WIDTH1_LEFT:
    fmla v21.4s, v0.4s, v9.4s
    fmla v21.4s, v1.4s, v10.4s
    fmla v21.4s, v2.4s, v11.4s
    fmla v21.4s, v3.4s, v13.4s
    fmla v21.4s, v4.4s, v14.4s
    fmla v21.4s, v5.4s, v15.4s
    fmla v21.4s, v6.4s, v17.4s
    fmla v21.4s, v7.4s, v18.4s
    fmla v21.4s, v8.4s, v19.4s
    cbnz x10, WIDTH1_RELU6
    cbnz x9, WIDTH1_RELU
    b WIDTH1_WRITE
WIDTH1_RELU6:
    fmin v21.4s, v21.4s, v23.4s
WIDTH1_RELU:
    fmax v21.4s, v21.4s, v24.4s
WIDTH1_WRITE:
    st1 {v21.4s}, [x0]

End:
    sub sp, sp, #128
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ret
#endif

@@ -0,0 +1,201 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Stride2
#ifndef __APPLE__
    .type ConvDw3x3Stride2, %function
#endif

// void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size,
//                       int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6)
//
// x0: output
// x1: input
// x2: weight
// x3: bias
// w4: col_size
// w5: row_size
// w6: channel
// w7: output_h
// w8: output_w
// w9: relu
// w10: relu6
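// Same nine-tap layout as the stride-1 kernel (weights in v0 ~ v8, input rows via
// x11/x12/x13, accumulators v24/v25), but with stride 2 the two output pixels span
// five input columns, so the still-needed columns are carried across iterations
// with register moves (mov v9.16b, v13.16b, etc.) instead of being reloaded.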
ConvDw3x3Stride2:
    sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

    ldr w8, [sp]
    ldr w9, [sp, #8]
    ldr w10, [sp, #16]
    mov w11, #4
    mul w15, w4, w11   // col_size * 4
    mul w16, w6, w11   // channel * 4
    mul w17, w5, w11   // row_size * 4
    mov w11, #2
    mul w14, w11, w15  // col_size * 2 * 4

    movi v26.4s, #6
    scvtf v26.4s, v26.4s
    dup v27.4s, wzr

    // Load weights
    ld1 {v0.4s}, [x2], x16
    ld1 {v1.4s}, [x2], x16
    ld1 {v2.4s}, [x2], x16
    ld1 {v3.4s}, [x2], x16
    ld1 {v4.4s}, [x2], x16
    ld1 {v5.4s}, [x2], x16
    ld1 {v6.4s}, [x2], x16
    ld1 {v7.4s}, [x2], x16
    ld1 {v8.4s}, [x2], x16

    mov x11, x1
    add x12, x11, x17
    add x13, x12, x17
    ld1 {v9.4s}, [x11], x15
    ld1 {v10.4s}, [x11], x15
    ld1 {v11.4s}, [x11], x15
    ld1 {v14.4s}, [x12], x15
    ld1 {v15.4s}, [x12], x15
    ld1 {v16.4s}, [x12], x15
    ld1 {v19.4s}, [x13], x15
    ld1 {v20.4s}, [x13], x15
    ld1 {v21.4s}, [x13], x15
    ld1 {v24.4s}, [x3]
    ld1 {v25.4s}, [x3]

    cmp w8, #2
    beq WIDTH2_LEFT
    cmp w8, #1
    beq WIDTH1_LEFT

WIDTH2_LOOP:
    fmla v24.4s, v0.4s, v9.4s
    ld1 {v12.4s}, [x11], x15
    fmla v25.4s, v0.4s, v11.4s
    ld1 {v17.4s}, [x12], x15
    fmla v24.4s, v1.4s, v10.4s
    ld1 {v22.4s}, [x13], x15
    fmla v25.4s, v1.4s, v12.4s
    ld1 {v13.4s}, [x11], x15
    fmla v24.4s, v2.4s, v11.4s
    ld1 {v18.4s}, [x12], x15
    fmla v25.4s, v2.4s, v13.4s
    ld1 {v23.4s}, [x13], x15
    fmla v24.4s, v3.4s, v14.4s
    mov v9.16b, v13.16b
    fmla v25.4s, v3.4s, v16.4s
    mov v14.16b, v18.16b
    fmla v24.4s, v4.4s, v15.4s
    fmla v25.4s, v4.4s, v17.4s
    ld1 {v10.4s}, [x11], x15
    fmla v24.4s, v5.4s, v16.4s
    ld1 {v11.4s}, [x11], x15
    fmla v25.4s, v5.4s, v18.4s
    ld1 {v15.4s}, [x12], x15
    fmla v24.4s, v6.4s, v19.4s
    ld1 {v16.4s}, [x12], x15
    fmla v25.4s, v6.4s, v21.4s
    mov v19.16b, v23.16b
    fmla v24.4s, v7.4s, v20.4s
    fmla v25.4s, v7.4s, v22.4s
    ld1 {v20.4s}, [x13], x15
    fmla v24.4s, v8.4s, v21.4s
    fmla v25.4s, v8.4s, v23.4s
    ld1 {v21.4s}, [x13], x15

    cbnz x10, WIDTH2_RELU6
    cbnz x9, WIDTH2_RELU
    b WIDTH2_WRITE
WIDTH2_RELU6:
    fmin v24.4s, v24.4s, v26.4s
    fmin v25.4s, v25.4s, v26.4s
WIDTH2_RELU:
    fmax v24.4s, v24.4s, v27.4s
    fmax v25.4s, v25.4s, v27.4s
WIDTH2_WRITE:
    st1 {v24.4s}, [x0], x16
    ld1 {v24.4s}, [x3]
    st1 {v25.4s}, [x0], x16
    ld1 {v25.4s}, [x3]

    sub w8, w8, #2
    cmp w8, #2
    bgt WIDTH2_LOOP
    cmp w8, #2
    blt WIDTH1_LEFT

WIDTH2_LEFT:
    fmla v24.4s, v0.4s, v9.4s
    ld1 {v12.4s}, [x11], x15
    fmla v25.4s, v0.4s, v11.4s
    ld1 {v17.4s}, [x12], x15
    fmla v24.4s, v1.4s, v10.4s
    ld1 {v22.4s}, [x13], x15
    fmla v25.4s, v1.4s, v12.4s
    ld1 {v13.4s}, [x11], x15
    fmla v24.4s, v2.4s, v11.4s
    ld1 {v18.4s}, [x12], x15
    fmla v25.4s, v2.4s, v13.4s
    ld1 {v23.4s}, [x13], x15
    fmla v24.4s, v3.4s, v14.4s
    fmla v25.4s, v3.4s, v16.4s
    fmla v24.4s, v4.4s, v15.4s
    fmla v25.4s, v4.4s, v17.4s
    fmla v24.4s, v5.4s, v16.4s
    fmla v25.4s, v5.4s, v18.4s
    fmla v24.4s, v6.4s, v19.4s
    fmla v25.4s, v6.4s, v21.4s
    fmla v24.4s, v7.4s, v20.4s
    fmla v25.4s, v7.4s, v22.4s
    fmla v24.4s, v8.4s, v21.4s
    fmla v25.4s, v8.4s, v23.4s

    cbnz x10, WIDTH2_LEFT_RELU6
    cbnz x9, WIDTH2_LEFT_RELU
    b WIDTH2_LEFT_WRITE
WIDTH2_LEFT_RELU6:
    fmin v24.4s, v24.4s, v26.4s
    fmin v25.4s, v25.4s, v26.4s
WIDTH2_LEFT_RELU:
    fmax v24.4s, v24.4s, v27.4s
    fmax v25.4s, v25.4s, v27.4s
WIDTH2_LEFT_WRITE:
    st1 {v24.4s}, [x0], x16
    st1 {v25.4s}, [x0], x16
    b End

WIDTH1_LEFT:
    fmla v24.4s, v0.4s, v9.4s
    fmla v24.4s, v1.4s, v10.4s
    fmla v24.4s, v2.4s, v11.4s
    fmla v24.4s, v3.4s, v14.4s
    fmla v24.4s, v4.4s, v15.4s
    fmla v24.4s, v5.4s, v16.4s
    fmla v24.4s, v6.4s, v19.4s
    fmla v24.4s, v7.4s, v20.4s
    fmla v24.4s, v8.4s, v21.4s
    cbnz x10, WIDTH1_RELU6
    cbnz x9, WIDTH1_RELU
    b WIDTH1_WRITE
WIDTH1_RELU6:
    fmin v24.4s, v24.4s, v26.4s
WIDTH1_RELU:
    fmax v24.4s, v24.4s, v27.4s
WIDTH1_WRITE:
    st1 {v24.4s}, [x0]

End:
    sub sp, sp, #128
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ret
#endif

@@ -0,0 +1,114 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Vertical
#ifndef __APPLE__
    .type ConvDw3x3Vertical, %function
#endif

// void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
//                        int in_kw_step, int channel, size_t relu, size_t relu6)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6
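// Computes one top/bottom border pixel over a 2x3 input window (two kernel rows,
// three columns); the portable reference is the C fallback
// ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, ...).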
ConvDw3x3Vertical:
    // Registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
    // x19 ~ x29 should also be preserved; this kernel does not use them, and our coding
    // style does not permit such a number of parameters anyway.
    ldr x8, [sp]
    mov x9, #4
    mul x13, x6, x9  // x6 * 4
    mul x4, x4, x9
    mul x5, x5, x9
    mov x9, #3
    mul x14, x13, x9  // x6 * 3 * 4

    movi v26.4s, #6
    scvtf v26.4s, v26.4s
    dup v27.4s, wzr

    ld1 {v23.4s}, [x3], #16
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    add x11, x1, x4
    ld1 {v4.4s}, [x10], x13
    ld1 {v1.4s}, [x9], x5
    add x12, x2, x14
    ld1 {v5.4s}, [x10], x13
    ld1 {v2.4s}, [x11], x5
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    ld1 {v7.4s}, [x12], x13
    ld1 {v16.4s}, [x9], x5
    ld1 {v18.4s}, [x10], x13
    ld1 {v17.4s}, [x11], x5
    ld1 {v19.4s}, [x12], x13

    cmp x6, #4
    ble LoopC4Post

LoopC4:
    add x1, x1, #16
    add x2, x2, #16
    fmla v23.4s, v0.4s, v4.4s
    mov x9, x1
    mov x10, x2
    ld1 {v0.4s}, [x9], x5
    ld1 {v4.4s}, [x10], x13
    add x11, x1, x4
    fmla v23.4s, v1.4s, v5.4s
    add x12, x2, x14
    ld1 {v1.4s}, [x9], x5
    fmla v23.4s, v2.4s, v6.4s
    ld1 {v5.4s}, [x10], x13
    ld1 {v2.4s}, [x11], x5
    fmla v23.4s, v3.4s, v7.4s
    ld1 {v6.4s}, [x12], x13
    ld1 {v3.4s}, [x11], x5
    fmla v23.4s, v16.4s, v18.4s
    ld1 {v7.4s}, [x12], x13
    ld1 {v16.4s}, [x9], x5
    fmla v23.4s, v17.4s, v19.4s
    ld1 {v18.4s}, [x10], x13
    ld1 {v17.4s}, [x11], x5
    ld1 {v19.4s}, [x12], x13

    cbnz x8, C4_RELU6
    cbnz x7, C4_RELU
    b C4_WRITE
C4_RELU6:
    fmin v23.4s, v23.4s, v26.4s
C4_RELU:
    fmax v23.4s, v23.4s, v27.4s
C4_WRITE:
    st1 {v23.4s}, [x0], #16
    ld1 {v23.4s}, [x3], #16
    sub x6, x6, #4
    cmp x6, #4
    bgt LoopC4

LoopC4Post:
    fmla v23.4s, v0.4s, v4.4s
    fmla v23.4s, v1.4s, v5.4s
    fmla v23.4s, v2.4s, v6.4s
    fmla v23.4s, v3.4s, v7.4s
    fmla v23.4s, v16.4s, v18.4s
    fmla v23.4s, v17.4s, v19.4s
    cbnz x8, RELU6
    cbnz x7, RELU
    b WRITE
RELU6:
    fmin v23.4s, v23.4s, v26.4s
RELU:
    fmax v23.4s, v23.4s, v27.4s
WRITE:
    st1 {v23.4s}, [x0], #16
    ret
#endif
| @@ -71,6 +71,21 @@ void ConvSwFp32Center(float *dst, const float *src, const float *weight, const f | |||||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, | size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, | ||||
| size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, | size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, | ||||
| size_t relu6); | size_t relu6); | ||||
| void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, | |||||
| int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); | |||||
| void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, | |||||
| int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); | |||||
| void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, | |||||
| int in_kw_step, int channel, size_t relu, size_t relu6); | |||||
| void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, | |||||
| int in_kw_step, int channel, size_t relu, size_t relu6); | |||||
| void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, | |||||
| int in_kw_step, int channel, size_t relu, size_t relu6); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -304,6 +304,279 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig | |||||
| } | } | ||||
| /*conv depthwise fp32 end*/ | /*conv depthwise fp32 end*/ | ||||
| /*conv depthwise 3x3 fp32 begin*/ | |||||
| bool CheckConvDwUse3X3(const ConvParameter *conv_param) { | |||||
| bool use_3x3 = | |||||
| conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && | |||||
| (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && | |||||
| (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ && | |||||
| (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && | |||||
| conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1; | |||||
| if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) { | |||||
| return false; | |||||
| } | |||||
| const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_; | |||||
| const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_; | |||||
| return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) && | |||||
| in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); | |||||
| } | |||||
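
// Worked example: kernel 3x3, stride 1, pad 1, input 224x224, output 224x224 gives
// in_h = (224 - 1) * 1 + 3 = 226 == 224 + 2 * 1, so the 3x3 windows cover the
// padded input exactly and the fast path applies; any shape where the windows do
// not tile the padded input exactly is rejected.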

void ConvDw3x3BorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                          int in_kh_step, int in_kw_step, int channel, bool relu, bool relu6) {
  for (int c = 0; c < channel; c += C4NUM) {
    for (int i = 0; i < C4NUM; i++) {
      dst[i] = 0;
    }
    const float *src_kh = src;
    const float *weight_kh = weight;
    for (int kh = 0; kh < height; kh++) {
      const float *src_kw = src_kh;
      const float *weight_kw = weight_kh;
      for (int kw = 0; kw < width; kw++) {
        for (int i = 0; i < C4NUM; i++) {
          dst[i] += src_kw[c + i] * weight_kw[c + i];
        }
        src_kw += in_kw_step;
        weight_kw += channel;
      }  // kernel_w loop
      src_kh += in_kh_step;
      weight_kh += 3 * channel;
    }  // kernel_h loop
    for (int i = 0; i < C4NUM; i++) {
      dst[i] += bias[c + i];
      dst[i] = (relu) ? (MSMAX(0, dst[i])) : (dst[i]);
      dst[i] = (relu6) ? (MSMIN(6, MSMAX(0, dst[i]))) : (dst[i]);
    }
    dst += C4NUM;
  }
}

#ifndef ENABLE_ARM64
void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                     int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}

void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                       int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, relu, relu6);
}

void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                         int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}
#endif
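
// On ARM64 builds these three border kernels resolve to the hand-written assembly
// implementations instead, which are declared with size_t relu/relu6 flags so the
// activation selectors arrive as full registers.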

void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  int input_row_size = conv_param->input_w_ * conv_param->input_channel_;
  int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_;
  int output_row_size = conv_param->output_w_ * conv_param->output_channel_;
  int in_kh_step = sliding->in_kh_step_;
  int in_kw_step = sliding->in_kw_step_;
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *input_batch =
      input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *output_batch =
      output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    // top
    const float *input = input_batch;
    const float *weight = weight_data + weight_row_size + conv_param->input_channel_;
    float *output = output_batch;
    ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
    input += (conv_param->stride_w_ - 1) * conv_param->input_channel_;
    weight = weight_data + weight_row_size;
    output += conv_param->output_channel_;
    for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
      ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
                        relu6);
      input += conv_param->stride_w_ * conv_param->input_channel_;
      output += conv_param->output_channel_;
    }
    ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
    // left
    input = input_batch + (conv_param->stride_h_ - 1) * input_row_size;
    weight = weight_data + conv_param->input_channel_;
    output = output_batch + output_row_size;
    for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
      ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
                          relu6);
      input += conv_param->stride_h_ * input_row_size;
      output += output_row_size;
    }
    // right
    input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ +
            (conv_param->stride_h_ - 1) * input_row_size;
    weight = weight_data;
    output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_;
    for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
      ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
                          relu6);
      input += conv_param->stride_h_ * input_row_size;
      output += output_row_size;
    }
    // bottom
    input = input_batch + (conv_param->input_h_ - 2) * input_row_size;
    weight = weight_data + conv_param->input_channel_;
    output = output_batch + (conv_param->output_h_ - 1) * output_row_size;
    ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
    input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_;
    weight = weight_data;
    output += conv_param->output_channel_;
    for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
      ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
                        relu6);
      input += conv_param->stride_w_ * conv_param->input_channel_;
      output += conv_param->output_channel_;
    }
    ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
  }
}
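
// ConvDw3x3InitBuffer below repacks a block_input_h x block_input_w patch of the
// input into a dense scratch buffer with a fixed 64-float (64-channel) pixel
// stride, so the inner kernels can address it with constant offsets regardless of
// the actual input_channel_.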
void ConvDw3x3InitBuffer(float *buffer, const float *input, const ConvParameter *conv_param, int block_input_h,
                         int block_input_w) {
  for (int h = 0; h < block_input_h; h++) {
    const float *src = input;
    for (int w = 0; w < block_input_w; w++) {
      memcpy(buffer, src, 64 * sizeof(float));
      src += conv_param->input_channel_;
      buffer += 64;
    }
    input += conv_param->input_w_ * conv_param->input_channel_;
  }
}

void ConvDw3x3Window(float *output, const float *buffer, const float *weight, const float *bias, int col_size,
                     int row_size, int channel, int output_h, int output_w, int stride, bool relu, bool relu6) {
  for (int w = 0; w < output_w; w++) {
    for (int i = 0; i < C4NUM; i++) {
      output[i] = bias[i];
    }
    const float *src_kh = buffer;
    const float *weight_kh = weight;
    for (int kh = 0; kh < 3; kh++) {
      const float *src_kw = src_kh;
      const float *weight_kw = weight_kh;
      for (int kw = 0; kw < 3; kw++) {
        for (int c = 0; c < C4NUM; c++) {
          output[c] += src_kw[c] * weight_kw[c];
        }
        src_kw += col_size;
        weight_kw += channel;
      }
      src_kh += row_size;
      weight_kh += 3 * channel;
    }
    for (int i = 0; i < C4NUM; i++) {
      output[i] = (relu) ? (MSMAX(0, output[i])) : (output[i]);
      output[i] = (relu6) ? (MSMIN(6, MSMAX(0, output[i]))) : (output[i]);
    }
    output += channel;
    buffer += col_size * stride;
  }
}

void ConvDw3x3Block(float *output, const float *buffer, const float *weight, const float *bias, int start_c, int end_c,
                    int col_size, int row_size, int channel, int output_h, int output_w, int stride, bool relu,
                    bool relu6) {
  for (; start_c <= end_c - C4NUM; start_c += C4NUM) {
#ifdef ENABLE_ARM64
    if (stride == 1) {
      ConvDw3x3Stride1(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
    } else {
      ConvDw3x3Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
    }
#else
    ConvDw3x3Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, stride, relu,
                    relu6);
#endif
    output += C4NUM;
    buffer += C4NUM;
    weight += C4NUM;
    bias += C4NUM;
  }
}

void ConvDw3x3Row(float *output, float *buffer, const float *input, const float *weight, const float *bias,
                  const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w,
                  int block_input_h, int block_input_w) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  const int ih_offset = 64 * block_input_w;
  int w = start_w;
  if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) {
    for (; w <= end_w - block_output_w; w += block_output_w) {
      float *output_ptr = output;
      const float *input_ptr = input;
      const float *weight_ptr = weight;
      const float *bias_ptr = bias;
      int c = 0;
      for (; c <= conv_param->output_channel_ - 64; c += 64) {
        ConvDw3x3InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
        ConvDw3x3Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
                       block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
        output_ptr += 64;
        input_ptr += 64;
        weight_ptr += 64;
        bias_ptr += 64;
      }
      // remaining channels (fewer than 64): run directly on the un-repacked input
      ConvDw3x3Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
                     conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
                     conv_param->input_channel_, block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
      output += block_output_w * conv_param->input_channel_;
      input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
    }
  }
  // remaining width
  int left_width = end_w - w;
  if (left_width > 0) {
    ConvDw3x3Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_,
                   conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h,
                   left_width, conv_param->stride_h_, relu, relu6);
  }
}

void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
               int task_id) {
  int output_h = sliding->bottom_ - sliding->top_;
  int step_oh = UP_DIV(output_h, conv_param->thread_num_);
  int start_oh = step_oh * task_id + sliding->top_;
  int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_);
  int start_ow = sliding->left_;
  int end_ow = sliding->right_;

  const int block_output_h = 1;
  int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14;
  const int block_input_h = 3;
  int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3;
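  // Buffer check: the widest patch is 3 x 32 pixels (stride 1: 1 * 29 + 3 = 32) or
  // 3 x 29 (stride 2: 2 * 13 + 3 = 29), and 3 * 32 * 64 = 6144 floats fits in the
  // 64 * 10 * 10 = 6400-float per-thread scratch buffer allocated by the caller.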

  for (int b = 0; b < conv_param->output_batch_; b++) {
    int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_;
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ +
                       start_ih * conv_param->input_w_ * conv_param->input_channel_ +
                       start_iw * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ +
                 start_oh * conv_param->output_w_ * conv_param->output_channel_ +
                 start_ow * conv_param->output_channel_;
    for (int oh = start_oh; oh < end_oh; oh++) {
      ConvDw3x3Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h,
                   block_output_w, block_input_h, block_input_w);
      src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_;
      dst += conv_param->output_w_ * conv_param->output_channel_;
    }
  }
}
/*conv depthwise 3x3 fp32 end*/

/*deconv depthwise fp32 begin*/
void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
                         int in_kw_step, int kernel_w_step) {
| @@ -45,6 +45,14 @@ void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter * | |||||
| void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, | void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, | ||||
| const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); | const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); | ||||
| bool CheckConvDwUse3X3(const ConvParameter *conv_param); | |||||
| void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, | |||||
| const ConvParameter *conv_param, const SlidingWindowParam *sliding); | |||||
| void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, | |||||
| const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); | |||||
| void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, | void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, | ||||
| const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); | const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); | ||||
| @@ -139,25 +139,6 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||||
| /*conv depthwise int8 end*/ | /*conv depthwise int8 end*/ | ||||
| /*conv depthwise 3x3 int8 begin*/ | /*conv depthwise 3x3 int8 begin*/ | ||||
| bool CheckConvDwInt8Use3X3(const ConvParameter *conv_param) { | |||||
| bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && | |||||
| (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && | |||||
| (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && | |||||
| conv_param->stride_h_ == conv_param->stride_w_ && | |||||
| (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && | |||||
| (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ && | |||||
| conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0); | |||||
| if (!use_3x3) { | |||||
| return false; | |||||
| } | |||||
| const int out_w = conv_param->output_w_ - 1; | |||||
| const int out_h = conv_param->output_h_ - 1; | |||||
| const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_; | |||||
| const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_; | |||||
| use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_); | |||||
| return use_3x3; | |||||
| } | |||||
| void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *conv_param, int block_input_h, | void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *conv_param, int block_input_h, | ||||
| int block_input_w) { | int block_input_w) { | ||||
| for (int h = 0; h < block_input_h; h++) { | for (int h = 0; h < block_input_h; h++) { | ||||
| @@ -428,63 +409,70 @@ void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16 | |||||
| int in_kw_step = sliding->in_kw_step_; | int in_kw_step = sliding->in_kw_step_; | ||||
| // top | // top | ||||
| const int8_t *input = input_data; | |||||
| const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; | |||||
| int8_t *output = output_data; | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; | |||||
| weight = weight_data + weight_row_size; | |||||
| output += conv_param->output_channel_; | |||||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||||
| for (int b = 0; b < conv_param->output_batch_; b++) { | |||||
| const int8_t *input_batch = | |||||
| input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; | |||||
| int8_t *output_batch = | |||||
| output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; | |||||
| const int8_t *input = input_batch; | |||||
| const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; | |||||
| int8_t *output = output_batch; | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; | |||||
| weight = weight_data + weight_row_size; | |||||
| output += conv_param->output_channel_; | output += conv_param->output_channel_; | ||||
| } | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| // left | |||||
| input = input_data + (conv_param->stride_h_ - 1) * input_row_size; | |||||
| weight = weight_data + conv_param->input_channel_; | |||||
| output = output_data + output_row_size; | |||||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | ||||
| input += conv_param->stride_h_ * input_row_size; | |||||
| output += output_row_size; | |||||
| } | |||||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||||
| output += conv_param->output_channel_; | |||||
| } | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| // left | |||||
| input = input_batch + (conv_param->stride_h_ - 1) * input_row_size; | |||||
| weight = weight_data + conv_param->input_channel_; | |||||
| output = output_batch + output_row_size; | |||||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, | |||||
| in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_h_ * input_row_size; | |||||
| output += output_row_size; | |||||
| } | |||||
| // right | |||||
| input = | |||||
| input_data + (conv_param->input_w_ - 2) * conv_param->input_channel_ + (conv_param->stride_h_ - 1) * input_row_size; | |||||
| weight = weight_data; | |||||
| output = output_data + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; | |||||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_h_ * input_row_size; | |||||
| output += output_row_size; | |||||
| } | |||||
| // right | |||||
| input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + | |||||
| (conv_param->stride_h_ - 1) * input_row_size; | |||||
| weight = weight_data; | |||||
| output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; | |||||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, | |||||
| in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_h_ * input_row_size; | |||||
| output += output_row_size; | |||||
| } | |||||
| // bottom | |||||
| input = input_data + (conv_param->input_h_ - 2) * input_row_size; | |||||
| weight = weight_data + conv_param->input_channel_; | |||||
| output = output_data + (conv_param->output_h_ - 1) * output_row_size; | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; | |||||
| weight = weight_data; | |||||
| output += conv_param->output_channel_; | |||||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||||
| // bottom | |||||
| input = input_batch + (conv_param->input_h_ - 2) * input_row_size; | |||||
| weight = weight_data + conv_param->input_channel_; | |||||
| output = output_batch + (conv_param->output_h_ - 1) * output_row_size; | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; | |||||
| weight = weight_data; | |||||
| output += conv_param->output_channel_; | output += conv_param->output_channel_; | ||||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||||
| output += conv_param->output_channel_; | |||||
| } | |||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| } | } | ||||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| } | } | ||||
| /*conv depthwise 3x3 int8 end*/ | /*conv depthwise 3x3 int8 end*/ | ||||
| @@ -24,8 +24,6 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| bool CheckConvDwInt8Use3X3(const ConvParameter *conv_param); | |||||
| void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, | void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, | ||||
| const int32_t *bias_data, const ConvParameter *conv_param, int task_id); | const int32_t *bias_data, const ConvParameter *conv_param, int task_id); | ||||

@@ -0,0 +1,149 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;

namespace mindspore::kernel {
ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() {
  if (packed_weight_ != nullptr) {
    free(packed_weight_);
    packed_weight_ = nullptr;
  }
  if (sliding_ != nullptr) {
    delete sliding_;
    sliding_ = nullptr;
  }
}

int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
  // init weight: k, h, w, c; k == group == output_channel, c == 1
  auto weight_tensor = in_tensors_[kWeightIndex];
  auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
  int channel = weight_tensor->Batch();
  int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();

  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), channel);

  bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, channel * sizeof(float));
  if (in_tensors_.size() == kInputSize2) {
    auto bias_tensor = in_tensors_[kBiasIndex];
    auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
    memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
  }
  return RET_OK;
}
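
// Note: PackWeightKHWToHWKFp32 turns the per-channel KxHxW weights into an HWK
// layout where each of the nine 3x3 taps is channel-contiguous; this matches the
// weight_kw += channel stepping in the C kernels and the channel*4-byte tap
// stride used by the assembly weight loads.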

int ConvolutionDepthwise3x3CPUKernel::Init() {
  sliding_ = new (std::nothrow) SlidingWindowParam;
  if (sliding_ == nullptr) {
    MS_LOG(ERROR) << "new sliding window param failed.";
    return RET_ERROR;
  }
  auto ret = InitWeightBias();
  if (ret != 0) {
    MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed.";
    return RET_ERROR;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ConvolutionDepthwise3x3CPUKernel::ReSize() {
  ConvolutionBaseCPUKernel::Init();
  InitSlidingParamConvDw(sliding_, conv_param_, conv_param_->input_channel_);
  conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
  return RET_OK;
}

int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
  auto buffer = buffer_ + 64 * 10 * 10 * task_id;
  ConvDw3x3(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
            sliding_, task_id);
  return RET_OK;
}

int ConvDw3x3Run(void *cdata, int task_id) {
  auto conv_dw = reinterpret_cast<ConvolutionDepthwise3x3CPUKernel *>(cdata);
  auto ret = conv_dw->Execute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionDepthwise3x3Run error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
  int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_;
  buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size * sizeof(float)));
  if (buffer_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwise3x3CPUKernel::Run() {
  auto ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitBuffer failed.";
    return ret;
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c());

  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());

  if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
      sliding_->right_ < conv_param_->output_w_) {
    ConvDw3x3Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
                 sliding_);
  }
  ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
  if (ret != RET_OK) {
    context_->allocator->Free(buffer_);
    MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
    return RET_ERROR;
  }
  context_->allocator->Free(buffer_);
  return RET_OK;
}
}  // namespace mindspore::kernel

@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv_depthwise.h"

namespace mindspore::kernel {
class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionDepthwise3x3CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                   const mindspore::lite::PrimitiveC *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ConvolutionDepthwise3x3CPUKernel() override;

  int Init() override;
  int ReSize() override;
  int Run() override;

  int InitWeightBias();
  int Execute(int task_id);

 private:
  int InitBuffer();
  SlidingWindowParam *sliding_ = nullptr;
  float *packed_weight_ = nullptr;
  float *input_ptr_ = nullptr;
  float *output_ptr_ = nullptr;
  float *buffer_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" | #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h" | |||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" | #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| @@ -136,10 +137,24 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| } | } | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| kernel::LiteKernel *kernel; | |||||
| if (conv_param->input_channel_ < 32) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } else { | |||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | |||||
| conv_param->input_h_ = inputs[kInputIndex]->Height(); | |||||
| conv_param->input_w_ = inputs[kInputIndex]->Width(); | |||||
| conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | |||||
| conv_param->output_h_ = outputs[kOutputIndex]->Height(); | |||||
| conv_param->output_w_ = outputs[kOutputIndex]->Width(); | |||||
| if (CheckConvDwUse3X3(conv_param) && conv_param->input_channel_ % C4NUM == 0) { | |||||
| #ifdef ENABLE_ARM64 | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| #endif | |||||
| } | |||||
| if (kernel == nullptr && conv_param->input_channel_ < 32) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| } | |||||
| if (kernel == nullptr) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| @@ -181,7 +181,7 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> | |||||
| conv_param->output_w_ = outputs[kOutputIndex]->Width(); | conv_param->output_w_ = outputs[kOutputIndex]->Width(); | ||||
| } | } | ||||
| auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | ||||
| if (CheckConvDwInt8Use3X3(conv_param) && weight_quant_size == 1) { | |||||
| if (CheckConvDwUse3X3(conv_param) && conv_param->input_channel_ % C8NUM == 0 && weight_quant_size == 1) { | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| kernel = | kernel = | ||||
| new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||