Merge pull request !6102 from ling/srtags/v1.0.0
@@ -6,9 +6,9 @@
   .type MatmulInt8Neon64, %function
 #endif

-//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
-//                      const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
-//                      int multiplier, int left_shift, int right_shift, int row, int col, int stride);
+//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
+//                      const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
+//                      int32_t *right_shift, int row, int col, int stride, int filter_peroc)

 // x0: a(left matrix ptr)
 // x1: b(right matrix ptr)
@@ -21,31 +21,34 @@
 // w8: act_min
 // w9: act_max
 // w10: out_zp
-// w11: multiplier
-// w12: left_shift
-// w13: right_shift
+// x11: multiplier
+// x12: left_shift
+// x13: right_shift
 // w14: row
 // w15: col
 // w24: stride
+// w27: filter_peroc

 MatmulInt8Neon64:
-  sub sp, sp, #192
+  sub sp, sp, #208
   st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
   st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
   stp x19, x20, [sp], #16
   stp x21, x22, [sp], #16
   stp x23, x24, [sp], #16
   stp x25, x26, [sp], #16
+  stp x27, x28, [sp], #16

   ldr w8, [sp]
   ldr w9, [sp, #8]
   ldr w10, [sp, #16]
-  ldr w11, [sp, #24]
-  ldr w12, [sp, #32]
-  ldr w13, [sp, #40]
+  ldr x11, [sp, #24]
+  ldr x12, [sp, #32]
+  ldr x13, [sp, #40]
   ldr w14, [sp, #48]
   ldr w15, [sp, #56]
   ldr w24, [sp, #64]
+  ldr w27, [sp, #72]

   mov w17, #4 // sizeof(int8)*4
   mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
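Reviewer note on the ABI change above: the three requantization parameters are now passed by pointer (x11/x12/x13) and a filter_peroc flag arrives at [sp, #72], hence the larger 208-byte frame and the extra x27/x28 save. Per-tensor callers keep working by pointing at a single scalar, as the updated call sites near the end of this diff do. A minimal sketch of that pattern (the wrapper function is mine; the prototype is copied from the header hunk further down):

```c
#include <stdint.h>

/* Prototype as changed by this PR (see the nnacl header hunk below). */
extern void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
                             const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
                             int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int row, int col,
                             int stride, int filter_peroc);

/* With filter_peroc == 0 the kernel only ever reads element [0] of each array,
 * so the address of one scalar is enough; the values here are illustrative. */
void PerTensorCallSketch(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias,
                         int row4, int col4, int deep16, int row, int col) {
  int32_t multiplier = 1073741824; /* 0.5 in Q31, example only */
  int32_t left_shift = 0;
  int32_t right_shift = 0;
  MatmulInt8Neon64(a, b, dst, row4, col4, deep16, a_sums, bias, INT8_MIN, INT8_MAX, /*out_zp=*/0,
                   &multiplier, &left_shift, &right_shift, row, col, /*stride=*/col, /*filter_peroc=*/0);
}
```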
@@ -58,7 +61,7 @@ L1:
   mov w16, w3 // reset a row4 counter
   mov w23, w14 // reset a row counter
   mov x17, x0 // reload a ptr
-  mov x22, x6 // reload a_sums ptr
+  mov x22, x6 // reload a_sums ptr
 L2:
   cmp w16, #0
   beq End2
@@ -167,39 +170,60 @@ End3:
   addp v19.4s, v28.4s, v30.4s

   // Add (Bias+Depth*Za*Zb-Za*Bsums)
-  ld1 {v15.4s}, [x19], #16
+  ld1 {v15.4s}, [x19], #16
   add v16.4s, v16.4s, v15.4s
   add v17.4s, v17.4s, v15.4s
   add v18.4s, v18.4s, v15.4s
   add v19.4s, v19.4s, v15.4s

   // Subtract (Asums*Zb)
+  cmp w27, #0
+  beq PerTLoad
+PerCLoad:
+  ld1 {v20.4s}, [x6], #16
+  ld1 {v21.4s}, [x6], #16
+  ld1 {v22.4s}, [x6], #16
+  ld1 {v23.4s}, [x6], #16
+  ld1 {v13.4s}, [x12]
+  ld1 {v12.4s}, [x11]
+  ld1 {v11.4s}, [x13]
+  b Apply
+PerTLoad:
   ld1 {v14.4s}, [x22], #16
   dup v20.4s, v14.s[0]
   dup v21.4s, v14.s[1]
   dup v22.4s, v14.s[2]
   dup v23.4s, v14.s[3]
+  ld1 {v14.s}[0], [x12]
+  dup v13.4s, v14.s[0]
+  ld1 {v14.s}[0], [x11]
+  dup v12.4s, v14.s[0]
+  ld1 {v14.s}[0], [x13]
+  dup v11.4s, v14.s[0]
+  b Apply
+Apply:
+  // Subtract (Asums*Zb)
   sub v16.4s, v16.4s, v20.4s
   sub v17.4s, v17.4s, v21.4s
   sub v18.4s, v18.4s, v22.4s
   sub v19.4s, v19.4s, v23.4s

   // Apply left shift
-  dup v13.4s, w12
   sqshl v16.4s, v16.4s, v13.4s
   sqshl v17.4s, v17.4s, v13.4s
   sqshl v18.4s, v18.4s, v13.4s
   sqshl v19.4s, v19.4s, v13.4s

   // Apply the fixed-point part of the multiplier.
-  dup v12.4s, w11
   sqrdmulh v16.4s, v16.4s, v12.4s
   sqrdmulh v17.4s, v17.4s, v12.4s
   sqrdmulh v18.4s, v18.4s, v12.4s
   sqrdmulh v19.4s, v19.4s, v12.4s

   // Apply right shift
-  dup v11.4s, w13
   and v20.16b, v11.16b, v16.16b
   sshr v20.4s, v20.4s, #31
   sqadd v16.4s, v16.4s, v20.4s
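For anyone tracing the rounding pipeline: here is a scalar model (mine, not from the patch) of the per-lane requantization performed by the sqshl / sqrdmulh / and+sshr+sqadd+srshl sequence above, following the usual gemmlowp-style scheme. Assumptions: left_shift is non-negative, right_shift is stored as a non-positive count (so srshl by it shifts right), and the saturation corner cases of the saturating instructions are not modeled:

```c
#include <stdint.h>

static int8_t RequantizeLaneModel(int32_t acc, int32_t left_shift, int32_t multiplier, int32_t right_shift,
                                  int32_t out_zp, int32_t act_min, int32_t act_max) {
  int64_t prod = ((int64_t)acc << left_shift) * (int64_t)multiplier; /* sqshl, then widening multiply */
  int32_t x = (int32_t)((prod + (1LL << 30)) >> 31);                 /* sqrdmulh: rounded doubling high half */
  int32_t fixup = (x & right_shift) >> 31; /* and + sshr #31: -1 when x < 0 and a right shift is pending */
  x += fixup;                              /* sqadd: nudges negatives so the rounding below is symmetric */
  int exponent = -right_shift;
  if (exponent > 0) {
    x = (x + (1 << (exponent - 1))) >> exponent; /* srshl with a negative count = rounding right shift */
  }
  x += out_zp; /* the dup/add on w10 later in the kernel */
  if (x < act_min) x = act_min;
  if (x > act_max) x = act_max;
  return (int8_t)x;
}
```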
@@ -268,7 +292,7 @@ Write:
   beq WriteCol2
   cmp w15, #1
   beq WriteCol1
 WriteCol4:
   st1 {v15.s}[0], [x2], x24
   cmp w23, #1
@@ -349,7 +373,7 @@ WriteCol1:
   st1 {v15.b}[12], [x2], x24
   b Endwrite

-Endwrite:
+Endwrite:
   sub w16, w16, #4 // a row4 counter - 4
   sub w23, w23, #4 // a row counter - 4
   b L2
@@ -361,15 +385,23 @@ End2:
   add x7, x7, #16 // bias ptr + stride
   add x25, x25, #4 // output + stride(4 * sizeof(int8))
   mov x2, x25

+  cmp w27, #0
+  beq PerTEnd2
+  add x12, x12, #16
+  add x11, x11, #16
+  add x13, x13, #16
+PerTEnd2:
   b L1

 End1:
-  sub sp, sp, #192
+  sub sp, sp, #208
   ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
   ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
   ldp x19, x20, [sp], #16
   ldp x21, x22, [sp], #16
   ldp x23, x24, [sp], #16
   ldp x25, x26, [sp], #16
+  ldp x27, x28, [sp], #16
   ret
 #endif
@@ -8,7 +8,7 @@
 //void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
 //                        const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
-//                        int multiplier, int left_shift, int right_shift, int row, int col, int stride);
+//                        int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, int peroc);

 // x0: a(left matrix ptr)
 // x1: b(right matrix ptr)
@@ -21,31 +21,34 @@
 // w8: act_min
 // w9: act_max
 // w10: out_zp
-// w11: multiplier
-// w12: left_shift
-// w13: right_shift
+// x11: multiplier
+// x12: left_shift
+// x13: right_shift
 // w14: row
 // w15: col
 // w24: stride
+// w27: filter_peroc

 MatmulInt8DpNeon64:
-  sub sp, sp, #192
+  sub sp, sp, #208
   st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
   st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
   stp x19, x20, [sp], #16
   stp x21, x22, [sp], #16
   stp x23, x24, [sp], #16
   stp x25, x26, [sp], #16
+  stp x27, x28, [sp], #16

   ldr w8, [sp]
   ldr w9, [sp, #8]
   ldr w10, [sp, #16]
-  ldr w11, [sp, #24]
-  ldr w12, [sp, #32]
-  ldr w13, [sp, #40]
+  ldr x11, [sp, #24]
+  ldr x12, [sp, #32]
+  ldr x13, [sp, #40]
   ldr w14, [sp, #48]
   ldr w15, [sp, #56]
   ldr w24, [sp, #64]
+  ldr w27, [sp, #72]

   mov w17, #8 // sizeof(int8)*8
   mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4
@@ -226,138 +229,171 @@ End3:
   add v29.4s, v29.4s, v14.4s
   add v31.4s, v31.4s, v14.4s

+  cmp w27, #0
+  beq PerTSumLoad
+PerCSumLoad:
+  ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
+  ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64
+  ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
+  ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64
+  b ApplySum
+PerTSumLoad:
+  ld1 {v14.4s}, [x22], #16
+  ld1 {v15.4s}, [x22], #16
+  dup v0.4s, v14.s[0]
+  dup v1.4s, v14.s[0]
+  dup v2.4s, v14.s[1]
+  dup v3.4s, v14.s[1]
+  dup v4.4s, v14.s[2]
+  dup v5.4s, v14.s[2]
+  dup v6.4s, v14.s[3]
+  dup v7.4s, v14.s[3]
+  dup v8.4s, v15.s[0]
+  dup v9.4s, v15.s[0]
+  dup v10.4s, v15.s[1]
+  dup v11.4s, v15.s[1]
+  dup v12.4s, v15.s[2]
+  dup v13.4s, v15.s[2]
+  dup v14.4s, v15.s[3]
+  dup v15.4s, v14.s[0]
+ApplySum:
   // Subtract (Asums*Zb)
-  ld1 {v13.4s}, [x22], #16
-  ld1 {v12.4s}, [x22], #16
-  dup v0.4s, v13.s[0]
-  dup v1.4s, v13.s[1]
-  dup v2.4s, v13.s[2]
-  dup v3.4s, v13.s[3]
-  dup v4.4s, v12.s[0]
-  dup v5.4s, v12.s[1]
-  dup v6.4s, v12.s[2]
-  dup v7.4s, v12.s[3]
   sub v16.4s, v16.4s, v0.4s
-  sub v17.4s, v17.4s, v0.4s
-  sub v18.4s, v18.4s, v1.4s
-  sub v19.4s, v19.4s, v1.4s
-  sub v20.4s, v20.4s, v2.4s
-  sub v21.4s, v21.4s, v2.4s
-  sub v22.4s, v22.4s, v3.4s
-  sub v23.4s, v23.4s, v3.4s
-  sub v24.4s, v24.4s, v4.4s
-  sub v25.4s, v25.4s, v4.4s
-  sub v26.4s, v26.4s, v5.4s
-  sub v27.4s, v27.4s, v5.4s
-  sub v28.4s, v28.4s, v6.4s
-  sub v29.4s, v29.4s, v6.4s
-  sub v30.4s, v30.4s, v7.4s
-  sub v31.4s, v31.4s, v7.4s
+  sub v17.4s, v17.4s, v1.4s
+  sub v18.4s, v18.4s, v2.4s
+  sub v19.4s, v19.4s, v3.4s
+  sub v20.4s, v20.4s, v4.4s
+  sub v21.4s, v21.4s, v5.4s
+  sub v22.4s, v22.4s, v6.4s
+  sub v23.4s, v23.4s, v7.4s
+  sub v24.4s, v24.4s, v8.4s
+  sub v25.4s, v25.4s, v9.4s
+  sub v26.4s, v26.4s, v10.4s
+  sub v27.4s, v27.4s, v11.4s
+  sub v28.4s, v28.4s, v12.4s
+  sub v29.4s, v29.4s, v13.4s
+  sub v30.4s, v30.4s, v14.4s
+  sub v31.4s, v31.4s, v15.4s

+  cmp w27, #0
+  beq PerTRoundLoad
+PerCRoundLoad:
+  ld1 {v8.4s, v9.4s}, [x12]
+  ld1 {v10.4s, v11.4s}, [x11]
+  ld1 {v12.4s, v13.4s}, [x13]
+  b ApplyRound
+PerTRoundLoad:
+  ld1 {v14.s}[0], [x12]
+  dup v8.4s, v14.s[0]
+  dup v9.4s, v14.s[0]
+  ld1 {v14.s}[0], [x11]
+  dup v10.4s, v14.s[0]
+  dup v11.4s, v14.s[0]
+  ld1 {v14.s}[0], [x13]
+  dup v12.4s, v14.s[0]
+  dup v13.4s, v14.s[0]
+ApplyRound:
   // Apply left shift
-  dup v11.4s, w12
-  sqshl v16.4s, v16.4s, v11.4s
-  sqshl v17.4s, v17.4s, v11.4s
-  sqshl v18.4s, v18.4s, v11.4s
-  sqshl v19.4s, v19.4s, v11.4s
-  sqshl v20.4s, v20.4s, v11.4s
-  sqshl v21.4s, v21.4s, v11.4s
-  sqshl v22.4s, v22.4s, v11.4s
-  sqshl v23.4s, v23.4s, v11.4s
-  sqshl v24.4s, v24.4s, v11.4s
-  sqshl v25.4s, v25.4s, v11.4s
-  sqshl v26.4s, v26.4s, v11.4s
-  sqshl v27.4s, v27.4s, v11.4s
-  sqshl v28.4s, v28.4s, v11.4s
-  sqshl v29.4s, v29.4s, v11.4s
-  sqshl v30.4s, v30.4s, v11.4s
-  sqshl v31.4s, v31.4s, v11.4s
+  sqshl v16.4s, v16.4s, v8.4s
+  sqshl v17.4s, v17.4s, v9.4s
+  sqshl v18.4s, v18.4s, v8.4s
+  sqshl v19.4s, v19.4s, v9.4s
+  sqshl v20.4s, v20.4s, v8.4s
+  sqshl v21.4s, v21.4s, v9.4s
+  sqshl v22.4s, v22.4s, v8.4s
+  sqshl v23.4s, v23.4s, v9.4s
+  sqshl v24.4s, v24.4s, v8.4s
+  sqshl v25.4s, v25.4s, v9.4s
+  sqshl v26.4s, v26.4s, v8.4s
+  sqshl v27.4s, v27.4s, v9.4s
+  sqshl v28.4s, v28.4s, v8.4s
+  sqshl v29.4s, v29.4s, v9.4s
+  sqshl v30.4s, v30.4s, v8.4s
+  sqshl v31.4s, v31.4s, v9.4s

   // Apply the fixed-point part of the multiplier.
-  dup v10.4s, w11
   sqrdmulh v16.4s, v16.4s, v10.4s
-  sqrdmulh v17.4s, v17.4s, v10.4s
+  sqrdmulh v17.4s, v17.4s, v11.4s
   sqrdmulh v18.4s, v18.4s, v10.4s
-  sqrdmulh v19.4s, v19.4s, v10.4s
+  sqrdmulh v19.4s, v19.4s, v11.4s
   sqrdmulh v20.4s, v20.4s, v10.4s
-  sqrdmulh v21.4s, v21.4s, v10.4s
+  sqrdmulh v21.4s, v21.4s, v11.4s
   sqrdmulh v22.4s, v22.4s, v10.4s
-  sqrdmulh v23.4s, v23.4s, v10.4s
+  sqrdmulh v23.4s, v23.4s, v11.4s
   sqrdmulh v24.4s, v24.4s, v10.4s
-  sqrdmulh v25.4s, v25.4s, v10.4s
+  sqrdmulh v25.4s, v25.4s, v11.4s
   sqrdmulh v26.4s, v26.4s, v10.4s
-  sqrdmulh v27.4s, v27.4s, v10.4s
+  sqrdmulh v27.4s, v27.4s, v11.4s
   sqrdmulh v28.4s, v28.4s, v10.4s
-  sqrdmulh v29.4s, v29.4s, v10.4s
+  sqrdmulh v29.4s, v29.4s, v11.4s
   sqrdmulh v30.4s, v30.4s, v10.4s
-  sqrdmulh v31.4s, v31.4s, v10.4s
+  sqrdmulh v31.4s, v31.4s, v11.4s

   // Apply right shift
-  dup v9.4s, w13
-  and v0.16b, v9.16b, v16.16b
+  and v0.16b, v12.16b, v16.16b
   sshr v0.4s, v0.4s, #31
   sqadd v16.4s, v16.4s, v0.4s
-  srshl v16.4s, v16.4s, v9.4s
-  and v1.16b, v9.16b, v17.16b
+  srshl v16.4s, v16.4s, v12.4s
+  and v1.16b, v13.16b, v17.16b
   sshr v1.4s, v1.4s, #31
   sqadd v17.4s, v17.4s, v1.4s
-  srshl v17.4s, v17.4s, v9.4s
-  and v2.16b, v9.16b, v18.16b
+  srshl v17.4s, v17.4s, v13.4s
+  and v2.16b, v12.16b, v18.16b
   sshr v2.4s, v2.4s, #31
   sqadd v18.4s, v18.4s, v2.4s
-  srshl v18.4s, v18.4s, v9.4s
-  and v3.16b, v9.16b, v19.16b
+  srshl v18.4s, v18.4s, v12.4s
+  and v3.16b, v13.16b, v19.16b
   sshr v3.4s, v3.4s, #31
   sqadd v19.4s, v19.4s, v3.4s
-  srshl v19.4s, v19.4s, v9.4s
-  and v0.16b, v9.16b, v20.16b
+  srshl v19.4s, v19.4s, v13.4s
+  and v0.16b, v12.16b, v20.16b
   sshr v0.4s, v0.4s, #31
   sqadd v20.4s, v20.4s, v0.4s
-  srshl v20.4s, v20.4s, v9.4s
-  and v1.16b, v9.16b, v21.16b
+  srshl v20.4s, v20.4s, v12.4s
+  and v1.16b, v13.16b, v21.16b
   sshr v1.4s, v1.4s, #31
   sqadd v21.4s, v21.4s, v1.4s
-  srshl v21.4s, v21.4s, v9.4s
-  and v2.16b, v9.16b, v22.16b
+  srshl v21.4s, v21.4s, v13.4s
+  and v2.16b, v12.16b, v22.16b
   sshr v2.4s, v2.4s, #31
   sqadd v22.4s, v22.4s, v2.4s
-  srshl v22.4s, v22.4s, v9.4s
-  and v3.16b, v9.16b, v23.16b
+  srshl v22.4s, v22.4s, v12.4s
+  and v3.16b, v13.16b, v23.16b
   sshr v3.4s, v3.4s, #31
   sqadd v23.4s, v23.4s, v3.4s
-  srshl v23.4s, v23.4s, v9.4s
-  and v0.16b, v9.16b, v24.16b
+  srshl v23.4s, v23.4s, v13.4s
+  and v0.16b, v12.16b, v24.16b
   sshr v0.4s, v0.4s, #31
   sqadd v24.4s, v24.4s, v0.4s
-  srshl v24.4s, v24.4s, v9.4s
-  and v1.16b, v9.16b, v25.16b
+  srshl v24.4s, v24.4s, v12.4s
+  and v1.16b, v13.16b, v25.16b
   sshr v1.4s, v1.4s, #31
   sqadd v25.4s, v25.4s, v1.4s
-  srshl v25.4s, v25.4s, v9.4s
-  and v2.16b, v9.16b, v26.16b
+  srshl v25.4s, v25.4s, v13.4s
+  and v2.16b, v12.16b, v26.16b
   sshr v2.4s, v2.4s, #31
   sqadd v26.4s, v26.4s, v2.4s
-  srshl v26.4s, v26.4s, v9.4s
-  and v3.16b, v9.16b, v27.16b
+  srshl v26.4s, v26.4s, v12.4s
+  and v3.16b, v13.16b, v27.16b
   sshr v3.4s, v3.4s, #31
   sqadd v27.4s, v27.4s, v3.4s
-  srshl v27.4s, v27.4s, v9.4s
-  and v0.16b, v9.16b, v28.16b
+  srshl v27.4s, v27.4s, v13.4s
+  and v0.16b, v12.16b, v28.16b
   sshr v0.4s, v0.4s, #31
   sqadd v28.4s, v28.4s, v0.4s
-  srshl v28.4s, v28.4s, v9.4s
-  and v1.16b, v9.16b, v29.16b
+  srshl v28.4s, v28.4s, v12.4s
+  and v1.16b, v13.16b, v29.16b
   sshr v1.4s, v1.4s, #31
   sqadd v29.4s, v29.4s, v1.4s
-  srshl v29.4s, v29.4s, v9.4s
-  and v2.16b, v9.16b, v30.16b
+  srshl v29.4s, v29.4s, v13.4s
+  and v2.16b, v12.16b, v30.16b
   sshr v2.4s, v2.4s, #31
   sqadd v30.4s, v30.4s, v2.4s
-  srshl v30.4s, v30.4s, v9.4s
-  and v3.16b, v9.16b, v31.16b
+  srshl v30.4s, v30.4s, v12.4s
+  and v3.16b, v13.16b, v31.16b
   sshr v3.4s, v3.4s, #31
   sqadd v31.4s, v31.4s, v3.4s
-  srshl v31.4s, v31.4s, v9.4s
+  srshl v31.4s, v31.4s, v13.4s

   // Add the destination zero point
   dup v8.4s, w10
@@ -793,15 +829,23 @@ End2:
   add x7, x7, #32 // bias ptr + stride
   add x25, x25, #8 // output + stride(8 * sizeof(int8))
   mov x2, x25

+  cmp w27, #0
+  beq PerTEnd2
+  add x12, x12, #32
+  add x11, x11, #32
+  add x13, x13, #32
+PerTEnd2:
   b L1

 End1:
-  sub sp, sp, #192
+  sub sp, sp, #208
   ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
   ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
   ldp x19, x20, [sp], #16
   ldp x21, x22, [sp], #16
   ldp x23, x24, [sp], #16
   ldp x25, x26, [sp], #16
+  ldp x27, x28, [sp], #16
   ret
 #endif
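A small map of the control flow added in both kernels, for reviewers skimming: End2 advances the per-oc parameter pointers by one column block (8 int32 = #32 here, 4 int32 = #16 in the 4x4 kernel) only when filter_peroc is set; in per-tensor mode they stay put and element [0] is reused. A rough C equivalent (function and names are illustrative, not from the patch):

```c
#include <stdint.h>

void PerOcParamWalkSketch(int col8, int filter_peroc, const int32_t *multiplier, const int32_t *left_shift,
                          const int32_t *right_shift) {
  for (int c = 0; c < col8; c += 8) { /* one trip through L1 per 8-column block */
    /* ... accumulate, then requantize the block with multiplier/left_shift/right_shift ... */
    if (filter_peroc != 0) {
      multiplier += 8;  /* add x11, x11, #32 */
      left_shift += 8;  /* add x12, x12, #32 */
      right_shift += 8; /* add x13, x13, #32 */
    }
  }
  (void)multiplier;
  (void)left_shift;
  (void)right_shift;
}
```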
@@ -385,12 +385,302 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
   int8_t *pack_r = packed_input;
   int32_t *input_sum_r = input_sum;

   /* per layer */
   for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
     const int8_t *src_ic = src_r;
     int8_t *pack_ic = pack_r;
     int32_t *input_sum_oc = input_sum_r;
+#ifdef ENABLE_ARM64
+    size_t src_stride = input_channel;
+    size_t ic_4res = input_channel - ic_4div;
+    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
+    asm volatile(
+      "dup v16.4s, wzr \n"
+      "dup v17.4s, wzr \n"
+      "mov x10, %[src_ic] \n"
+      "mov x11, %[pack_ic] \n"
+      "mov x0, #0 \n"
+      "1: \n"
+      "cmp x0, %[ic_4div] \n"
+      "add x0, x0, #4\n"
+      "mov x12, x10 \n"
+      "add x10, x10, #4\n"
+      "blt 2f \n"
+      "cmp %[ic_4res], #0\n"
+      "beq 6f \n"
+      "cmp %[ic_4res], #1\n"
+      "beq 3f \n"
+      "cmp %[ic_4res], #2\n"
+      "beq 4f \n"
+      "cmp %[ic_4res], #3\n"
+      "beq 5f \n"
+      "2: \n"
+      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
+      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
+      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
+      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
+      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
+      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
+      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
+      "ld1 {v1.s}[3], [x12], %[src_stride]\n"
+      "st1 {v0.16b}, [x11], #16\n"
+      "st1 {v1.16b}, [x11], #16\n"
+      "saddlp v4.8h, v0.16b \n"
+      "saddlp v5.8h, v1.16b \n"
+      "saddlp v0.4s, v4.8h \n"
+      "saddlp v1.4s, v5.8h \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
+      "b 1b \n"
+      "3: \n" /* col res 1 */
+      "dup v0.4s, wzr \n"
+      "dup v1.4s, wzr \n"
+      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[12], [x12], %[src_stride]\n"
+      "st1 {v0.16b}, [x11], #16\n"
+      "st1 {v1.16b}, [x11], #16\n"
+      "saddlp v4.8h, v0.16b \n"
+      "saddlp v5.8h, v1.16b \n"
+      "saddlp v0.4s, v4.8h \n"
+      "saddlp v1.4s, v5.8h \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
+      "b 6f \n"
+      "4: \n" /* col res 2 */
+      "dup v0.4s, wzr \n"
+      "dup v1.4s, wzr \n"
+      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
+      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
+      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
+      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
+      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
+      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
+      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
+      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
+      "st1 {v0.16b}, [x11], #16\n"
+      "st1 {v1.16b}, [x11], #16\n"
+      "saddlp v4.8h, v0.16b \n"
+      "saddlp v5.8h, v1.16b \n"
+      "saddlp v0.4s, v4.8h \n"
+      "saddlp v1.4s, v5.8h \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
+      "b 6f \n"
+      "5: \n" /* col res 3 */
+      "dup v0.4s, wzr \n"
+      "dup v1.4s, wzr \n"
+      "add x13, x12, #2 \n"
+      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
+      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
+      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
+      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
+      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
+      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
+      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
+      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
+      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
+      "ld1 {v1.b}[14], [x13], %[src_stride]\n"
+      "st1 {v0.16b}, [x11], #16\n"
+      "st1 {v1.16b}, [x11], #16\n"
+      "saddlp v4.8h, v0.16b \n"
+      "saddlp v5.8h, v1.16b \n"
+      "saddlp v0.4s, v4.8h \n"
+      "saddlp v1.4s, v5.8h \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
+      "b 6f \n"
+      "6: \n"
+      "dup v0.4s, v16.s[0] \n"
+      "dup v1.4s, v16.s[1] \n"
+      "dup v2.4s, v16.s[2] \n"
+      "dup v3.4s, v16.s[3] \n"
+      "dup v4.4s, v17.s[0] \n"
+      "dup v5.4s, v17.s[1] \n"
+      "dup v6.4s, v17.s[2] \n"
+      "dup v7.4s, v17.s[3] \n"
+      "mov x4, #0 \n"
+      "mov x10, %[filter_zp] \n"
+      "mov x11, %[input_sum_oc] \n"
+      "7: \n"
+      "cmp x4, %[oc_8div] \n"
+      "beq 8f \n"
+      "add x4, x4, #8\n"
+      "ld1 {v16.4s}, [x10], #16\n"
+      "ld1 {v17.4s}, [x10], #16\n"
+      "mul v18.4s, v16.4s, v0.4s \n"
+      "mul v19.4s, v17.4s, v0.4s \n"
+      "st1 {v18.4s}, [x11], #16 \n"
+      "st1 {v19.4s}, [x11], #16 \n"
+      "mul v20.4s, v16.4s, v1.4s \n"
+      "mul v21.4s, v17.4s, v1.4s \n"
+      "st1 {v20.4s}, [x11], #16 \n"
+      "st1 {v21.4s}, [x11], #16 \n"
+      "mul v22.4s, v16.4s, v2.4s \n"
+      "mul v23.4s, v17.4s, v2.4s \n"
+      "st1 {v22.4s}, [x11], #16 \n"
+      "st1 {v23.4s}, [x11], #16 \n"
+      "mul v24.4s, v16.4s, v3.4s \n"
+      "mul v25.4s, v17.4s, v3.4s \n"
+      "st1 {v24.4s}, [x11], #16 \n"
+      "st1 {v25.4s}, [x11], #16 \n"
+      "mul v18.4s, v16.4s, v4.4s \n"
+      "mul v19.4s, v17.4s, v4.4s \n"
+      "st1 {v18.4s}, [x11], #16 \n"
+      "st1 {v19.4s}, [x11], #16 \n"
+      "mul v20.4s, v16.4s, v5.4s \n"
+      "mul v21.4s, v17.4s, v5.4s \n"
+      "st1 {v20.4s}, [x11], #16 \n"
+      "st1 {v21.4s}, [x11], #16 \n"
+      "mul v22.4s, v16.4s, v6.4s \n"
+      "mul v23.4s, v17.4s, v6.4s \n"
+      "st1 {v22.4s}, [x11], #16 \n"
+      "st1 {v23.4s}, [x11], #16 \n"
+      "mul v24.4s, v16.4s, v7.4s \n"
+      "mul v25.4s, v17.4s, v7.4s \n"
+      "st1 {v24.4s}, [x11], #16 \n"
+      "st1 {v25.4s}, [x11], #16 \n"
+      "add x11, x11, %[input_sum_stride] \n"
+      "b 7b \n"
+      "8: \n"
+      "cmp %[oc_8res], #0\n"
+      "beq 17f \n"
+      "dup v16.4s, wzr \n"
+      "dup v17.4s, wzr \n"
+      "cmp %[oc_8res], #1\n"
+      "beq 9f \n"
+      "cmp %[oc_8res], #2\n"
+      "beq 10f \n"
+      "cmp %[oc_8res], #3\n"
+      "beq 11f \n"
+      "cmp %[oc_8res], #4\n"
+      "beq 12f \n"
+      "cmp %[oc_8res], #5\n"
+      "beq 13f \n"
+      "cmp %[oc_8res], #6\n"
+      "beq 14f \n"
+      "cmp %[oc_8res], #7\n"
+      "beq 15f \n"
+      "9: \n"
+      "ld1 {v16.s}[0], [x10] \n"
+      "b 16f \n"
+      "10: \n"
+      "ld1 {v16.d}[0], [x10] \n"
+      "b 16f \n"
+      "11: \n"
+      "ld1 {v16.d}[0], [x10] \n"
+      "add x10, x10, #8 \n"
+      "ld1 {v16.s}[2], [x10] \n"
+      "b 16f \n"
+      "12: \n"
+      "ld1 {v16.4s}, [x10] \n"
+      "b 16f \n"
+      "13: \n"
+      "ld1 {v16.4s}, [x10], #16\n"
+      "ld1 {v17.s}[0], [x10] \n"
+      "b 16f \n"
+      "14: \n"
+      "ld1 {v16.4s}, [x10], #16\n"
+      "ld1 {v17.d}[0], [x10] \n"
+      "b 16f \n"
+      "15: \n"
+      "ld1 {v16.4s}, [x10], #16\n"
+      "ld1 {v17.d}[0], [x10] \n"
+      "add x10, x10, #8 \n"
+      "ld1 {v17.s}[2], [x10] \n"
+      "b 16f \n"
+      "16: \n"
+      "mul v18.4s, v16.4s, v0.4s \n"
+      "mul v19.4s, v17.4s, v0.4s \n"
+      "mul v20.4s, v16.4s, v1.4s \n"
+      "mul v21.4s, v17.4s, v1.4s \n"
+      "mul v22.4s, v16.4s, v2.4s \n"
+      "mul v23.4s, v17.4s, v2.4s \n"
+      "mul v24.4s, v16.4s, v3.4s \n"
+      "mul v25.4s, v17.4s, v3.4s \n"
+      "st1 {v18.4s}, [x11], #16 \n"
+      "st1 {v19.4s}, [x11], #16 \n"
+      "st1 {v20.4s}, [x11], #16 \n"
+      "st1 {v21.4s}, [x11], #16 \n"
+      "st1 {v22.4s}, [x11], #16 \n"
+      "st1 {v23.4s}, [x11], #16 \n"
+      "st1 {v24.4s}, [x11], #16 \n"
+      "st1 {v25.4s}, [x11], #16 \n"
+      "mul v18.4s, v16.4s, v4.4s \n"
+      "mul v19.4s, v17.4s, v4.4s \n"
+      "mul v20.4s, v16.4s, v5.4s \n"
+      "mul v21.4s, v17.4s, v5.4s \n"
+      "mul v22.4s, v16.4s, v6.4s \n"
+      "mul v23.4s, v17.4s, v6.4s \n"
+      "mul v24.4s, v16.4s, v7.4s \n"
+      "mul v25.4s, v17.4s, v7.4s \n"
+      "st1 {v18.4s}, [x11], #16 \n"
+      "st1 {v19.4s}, [x11], #16 \n"
+      "st1 {v20.4s}, [x11], #16 \n"
+      "st1 {v21.4s}, [x11], #16 \n"
+      "st1 {v22.4s}, [x11], #16 \n"
+      "st1 {v23.4s}, [x11], #16 \n"
+      "st1 {v24.4s}, [x11], #16 \n"
+      "st1 {v25.4s}, [x11], #16 \n"
+      "17: \n"
+      :
+      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
+        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
+        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
+      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
+        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
+#else
     int32_t tmp_sum_value[8] = {0};
     for (int ici = 0; ici < ic_4div; ici += C4NUM) {
       for (int i = 0; i < C8NUM; i++) {
@@ -440,7 +730,7 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
         }
       }
     } /* oc8 res done */
+#endif

     src_r += input_channel * C8NUM;
     pack_r += ic4 * C8NUM;
     input_sum_r += C8NUM * C8NUM;
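What the new per-oc block computes, in reference form: each of the 8 rows in the tile gets its input-channel sum multiplied by every output channel's filter zero point (the scalar #else path does the same). A simplified model with the real oc8-blocked input_sum layout deliberately flattened; the function and its indexing are illustrative only, not code from the patch:

```c
#include <stdint.h>

void InputSumPerOcSketch(const int8_t *src, int input_channel, int output_channel, const int32_t *filter_zp,
                         int32_t *input_sum /* [8][output_channel], illustrative layout */) {
  for (int r = 0; r < 8; r++) { /* C8NUM rows per tile */
    int32_t row_sum = 0;
    for (int ci = 0; ci < input_channel; ci++) {
      row_sum += src[r * input_channel + ci]; /* the saddlp/add chain */
    }
    for (int oc = 0; oc < output_channel; oc++) {
      input_sum[r * output_channel + oc] = row_sum * filter_zp[oc]; /* the mul/st1 loop */
    }
  }
}
```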
@@ -520,9 +810,9 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
     size_t src_stride = input_channel;
     size_t ic_4res = input_channel - ic_4div;
     asm volatile(
-      "dup v10.4s, wzr \n"
-      "dup v11.4s, wzr \n"
-      "mov x20, %[input_sum_r] \n"
+      "dup v16.4s, wzr \n"
+      "dup v17.4s, wzr \n"
+      "mov x14, %[input_sum_r] \n"
       "dup v20.4s, %w[filter_zp] \n"
       "mov x10, %[src_ic] \n"
@@ -563,8 +853,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
       "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
-      "add v10.4s, v10.4s, v0.4s \n"
-      "add v11.4s, v11.4s, v1.4s \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
       "b 1b \n"

       "3: \n" /* col res 1 */
@@ -586,8 +876,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
       "saddlp v5.8h, v1.16b \n"
       "saddlp v0.4s, v4.8h \n"
       "saddlp v1.4s, v5.8h \n"
-      "add v10.4s, v10.4s, v0.4s \n"
-      "add v11.4s, v11.4s, v1.4s \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
       "b 6f \n"

       "4: \n" /* col res 2 */
@@ -609,8 +899,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
       "saddlp v5.8h, v1.16b \n"
       "saddlp v0.4s, v4.8h \n"
       "saddlp v1.4s, v5.8h \n"
-      "add v10.4s, v10.4s, v0.4s \n"
-      "add v11.4s, v11.4s, v1.4s \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
       "b 6f \n"

       "5: \n" /* col res 3 */
@@ -641,21 +931,21 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
       "saddlp v5.8h, v1.16b \n"
       "saddlp v0.4s, v4.8h \n"
       "saddlp v1.4s, v5.8h \n"
-      "add v10.4s, v10.4s, v0.4s \n"
-      "add v11.4s, v11.4s, v1.4s \n"
+      "add v16.4s, v16.4s, v0.4s \n"
+      "add v17.4s, v17.4s, v1.4s \n"
       "b 6f \n"

       "6: \n"
-      "mul v10.4s, v10.4s, v20.4s \n"
-      "mul v11.4s, v11.4s, v20.4s \n"
+      "mul v16.4s, v16.4s, v20.4s \n"
+      "mul v17.4s, v17.4s, v20.4s \n"

-      "st1 {v10.4s}, [x20], #16 \n"
-      "st1 {v11.4s}, [x20], #16 \n"
+      "st1 {v16.4s}, [x14], #16 \n"
+      "st1 {v17.4s}, [x14], #16 \n"

       :
       : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
         [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
-      : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
+      : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
         "v20");
 #else
     int32_t tmp_sum_value[8] = {0};
@@ -728,10 +1018,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
 void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                     const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                     int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
+  int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
   matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
               left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-              conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-              conv_param->conv_quant_arg_.filter_arg_num_ != 1);
+              conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc);
   return;
 }
@@ -756,24 +1046,17 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
 void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                  const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                  int32_t *multiplier, ConvParameter *conv_param) {
-  if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) {
-    return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum,
-                             bias, left_shift, right_shift, multiplier,
-                             conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-                             conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                             conv_param->conv_quant_arg_.filter_arg_num_ != 1);
-  }
+  int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
 #ifdef ENABLE_ARM64
   MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
                    bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-                   conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
-                   conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
+                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col,
+                   conv_param->output_channel_, is_per_oc);
 #else
   MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
                     left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                     conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                    conv_param->conv_quant_arg_.filter_arg_num_ != 1);
+                    is_per_oc);
 #endif
   return;
 }
@@ -269,7 +269,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
 void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                       size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                       int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
-                      bool per_channel) {
+                      size_t per_channel) {
   /* row8x4-major * row4x8-major => (int8)row-major */
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
@@ -39,7 +39,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
 void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                       size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                       int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
-                      bool per_channel);
+                      size_t per_channel);
 void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
@@ -59,8 +59,8 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
 #ifdef ENABLE_ARM64
 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
-                      const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
-                      int right_shift, int row, int col, int stride);
+                      const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
+                      int32_t *right_shift, int row, int col, int stride, int filter_peroc);
 void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
                         const int *input_sum, const int *bias);
@@ -25,7 +25,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, i
 typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                                   size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                   int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
-                                  int32_t maxi, bool per_channel);
+                                  int32_t maxi, size_t per_channel);

 typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
@@ -30,8 +30,9 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
 extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias);
 extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
-                               const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier,
-                               int left_shift, int right_shift, int row, int col, int stride);
+                               const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
+                               int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
+                               size_t peroc);

 #ifdef __cplusplus
 }
@@ -55,8 +56,8 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
 void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                                   size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                   int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
-                                  int32_t maxi, bool per_channel) {
+                                  int32_t maxi, size_t per_channel) {
   return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
-                            output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, stride);
+                            output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
 }
 #endif
@@ -273,13 +273,13 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
       "mov x11, %[input_sum] \n"
       "mov x15, %[filter_zp_ptr] \n"

-      "mov x0, #0 \n" // row 4 count
+      "mov x0, #0 \n"
       "1: \n"
       "cmp x0, %[hw4] \n"
       "beq 11f \n"
       "add x0, x0, #4\n"

       "dup v10.4s, wzr \n"
-      "mov x2, #0 \n" // input deep count
+      "mov x2, #0 \n"
       "mov x16, x15 \n"

       "2: \n"
@@ -313,9 +313,9 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
       "b 2b \n"

       "3: \n"
-      "mov x12, x11 \n" // tmp inputsm inputsum hw
+      "mov x12, x11 \n"
       "add x11, x11, #64 \n"
-      "mov x4, #0 \n" // oc count
+      "mov x4, #0 \n"

       "dup v1.4s, v10.s[0] \n"
       "dup v2.4s, v10.s[1] \n"
@@ -46,6 +46,18 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
     free(filter_zp_ptr_);
     filter_zp_ptr_ = nullptr;
   }
+  if (filter_peroc_ && left_shift_ != nullptr) {
+    free(left_shift_);
+    left_shift_ = nullptr;
+  }
+  if (filter_peroc_ && right_shift_ != nullptr) {
+    free(right_shift_);
+    right_shift_ = nullptr;
+  }
+  if (filter_peroc_ && multiplier_ != nullptr) {
+    free(multiplier_);
+    multiplier_ = nullptr;
+  }
   FreeResizeBuf();
   FreeQuantParam();
 }
@@ -59,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
 }

 void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
-  support_optimize_ = false;
+  support_optimize_ = true;
   matmul_func_ = MatMulInt8_8x8_r;
 #ifdef ENABLE_ARM64
   void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@@ -78,10 +90,6 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
     support_optimize_ = false;
     matmul_func_ = nullptr;
   }
-
-  if (filter_peroc_) {
-    support_optimize_ = false;
-  }
 #endif
   return;
 }
@@ -109,6 +117,26 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
     for (int fi = 0; fi < output_channel; fi++) {
       filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
     }
+
+    int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM);
+    left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
+    if (left_shift_ == nullptr) {
+      return RET_ERROR;
+    }
+    memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t));
+    memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
+    right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
+    if (right_shift_ == nullptr) {
+      return RET_ERROR;
+    }
+    memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t));
+    memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
+    multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
+    if (multiplier_ == nullptr) {
+      return RET_ERROR;
+    }
+    memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t));
+    memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
   }
   return RET_OK;
 }
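The three allocations above share one pattern: allocate at the oc4/oc8 up-rounded size, zero the padding tail, copy the real per-channel values. The padding matters because RunImpl (next hunks) hands each thread a C4/C8-aligned slice of these arrays, which can extend past output_channel. A hypothetical helper capturing the pattern (mine, not part of the PR):

```c
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

static int32_t *CopyToUpRoundedBuffer(const int32_t *src, int output_channel, int up_round_oc_size) {
  int32_t *dst = (int32_t *)malloc(up_round_oc_size * sizeof(int32_t));
  if (dst == NULL) {
    return NULL;
  }
  memset(dst, 0, up_round_oc_size * sizeof(int32_t)); /* zero the oc4/oc8 padding tail */
  memcpy(dst, src, output_channel * sizeof(int32_t)); /* copy the real per-oc values */
  return dst;
}
```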
@@ -328,9 +356,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
   }
   if (filter_peroc_) {
     cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
-    cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM;
-    cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM;
-    cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM;
+    cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM;
+    cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM;
+    cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM;
   }
   Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
                  output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
@@ -346,9 +374,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
   }
   if (filter_peroc_) {
     cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
-    cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM;
-    cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM;
-    cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM;
+    cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM;
+    cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM;
+    cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM;
   }
   Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
               output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,
@@ -58,8 +58,11 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
   int InitBiasByzp(void *src_weight, int input_channel, int output_channel);

  private:
-  int32_t *input_sum_ = nullptr;     /* per-channel: oc4 format */
-  int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */
+  int32_t *input_sum_ = nullptr;     /* per-oc: oc4 format */
+  int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
+  int32_t *left_shift_ = nullptr;    /* per-oc up round */
+  int32_t *right_shift_ = nullptr;   /* per-oc up round */
+  int32_t *multiplier_ = nullptr;    /* per-oc up round */
   int8_t *packed_weight_ = nullptr;
   int8_t *packed_input_ = nullptr;
   int8_t *input_ptr_ = nullptr;
@@ -108,8 +108,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
   auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
 #ifdef ENABLE_ARM64
   MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
-                   q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res,
-                   p->col_ * sizeof(int8_t));
+                   q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res,
+                   p->col_ * sizeof(int8_t), 0);
 #else
   MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_,
              q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_);
@@ -101,8 +101,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) {
   auto &p = quant_params_;
 #ifdef ENABLE_ARM64
   MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
-                   p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res,
-                   params_->col_ * sizeof(int8_t));
+                   p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
+                   params_->col_ * sizeof(int8_t), false);
 #else
   MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier,
              p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_);
@@ -120,8 +120,8 @@ TEST_F(TestMatmulInt8, simple) {
   int multiplier, ls, rs;
   QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
 #ifdef ENABLE_ARM64
-  MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls,
-                   rs, ROW, COL, COL);
+  MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
+                   &rs, ROW, COL, COL, false);
 #else
   MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL);
 #endif