Merge pull request !6102 from ling/srtags/v1.0.0
| @@ -6,9 +6,9 @@ | |||||
| .type MatmulInt8Neon64, %function | .type MatmulInt8Neon64, %function | ||||
| #endif | #endif | ||||
| //void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, | |||||
| // const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, | |||||
| // int multiplier, int left_shift, int right_shift, int row, int col, int stride); | |||||
| //void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, | |||||
| // const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, | |||||
| // int32_t *right_shift, int row, int col, int stride, int filter_peroc) | |||||
| // x0: a(left matrix ptr) | // x0: a(left matrix ptr) | ||||
| // x1: b(right matrix ptr) | // x1: b(right matrix ptr) | ||||
| @@ -21,31 +21,34 @@ | |||||
| // w8: act_min | // w8: act_min | ||||
| // w9: act_max | // w9: act_max | ||||
| // w10: out_zp | // w10: out_zp | ||||
| // w11: multiplier | |||||
| // w12: left_shift | |||||
| // w13: right_shift | |||||
| // x11: multiplier | |||||
| // x12: left_shift | |||||
| // x13: right_shift | |||||
| // w14: row | // w14: row | ||||
| // w15: col | // w15: col | ||||
| // w24: stride | // w24: stride | ||||
| // w27: filter_peroc | |||||
| MatmulInt8Neon64: | MatmulInt8Neon64: | ||||
| sub sp, sp, #192 | |||||
| sub sp, sp, #208 | |||||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| stp x19, x20, [sp], #16 | stp x19, x20, [sp], #16 | ||||
| stp x21, x22, [sp], #16 | stp x21, x22, [sp], #16 | ||||
| stp x23, x24, [sp], #16 | stp x23, x24, [sp], #16 | ||||
| stp x25, x26, [sp], #16 | stp x25, x26, [sp], #16 | ||||
| stp x27, x28, [sp], #16 | |||||
| ldr w8, [sp] | ldr w8, [sp] | ||||
| ldr w9, [sp, #8] | ldr w9, [sp, #8] | ||||
| ldr w10, [sp, #16] | ldr w10, [sp, #16] | ||||
| ldr w11, [sp, #24] | |||||
| ldr w12, [sp, #32] | |||||
| ldr w13, [sp, #40] | |||||
| ldr x11, [sp, #24] | |||||
| ldr x12, [sp, #32] | |||||
| ldr x13, [sp, #40] | |||||
| ldr w14, [sp, #48] | ldr w14, [sp, #48] | ||||
| ldr w15, [sp, #56] | ldr w15, [sp, #56] | ||||
| ldr w24, [sp, #64] | ldr w24, [sp, #64] | ||||
| ldr w27, [sp, #72] | |||||
| mov w17, #4 // sizeof(int8)*4 | mov w17, #4 // sizeof(int8)*4 | ||||
| mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 | mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 | ||||
| @@ -58,7 +61,7 @@ L1: | |||||
| mov w16, w3 // reset a row4 counter | mov w16, w3 // reset a row4 counter | ||||
| mov w23, w14 // reset a row counter | mov w23, w14 // reset a row counter | ||||
| mov x17, x0 // reload a ptr | mov x17, x0 // reload a ptr | ||||
| mov x22, x6 // reload a_sums ptr | |||||
| mov x22, x6 // reload a_sums ptr | |||||
| L2: | L2: | ||||
| cmp w16, #0 | cmp w16, #0 | ||||
| beq End2 | beq End2 | ||||
| @@ -167,39 +170,60 @@ End3: | |||||
| addp v19.4s, v28.4s, v30.4s | addp v19.4s, v28.4s, v30.4s | ||||
| // Add (Bias+Depth*Za*Zb-Za*Bsums) | // Add (Bias+Depth*Za*Zb-Za*Bsums) | ||||
| ld1 {v15.4s}, [x19], #16 | |||||
| ld1 {v15.4s}, [x19], #16 | |||||
| add v16.4s, v16.4s, v15.4s | add v16.4s, v16.4s, v15.4s | ||||
| add v17.4s, v17.4s, v15.4s | add v17.4s, v17.4s, v15.4s | ||||
| add v18.4s, v18.4s, v15.4s | add v18.4s, v18.4s, v15.4s | ||||
| add v19.4s, v19.4s, v15.4s | add v19.4s, v19.4s, v15.4s | ||||
| // Subtract (Asums*Zb) | |||||
| cmp w27, #0 | |||||
| beq PerTLoad | |||||
| PerCLoad: | |||||
| ld1 {v20.4s}, [x6], #16 | |||||
| ld1 {v21.4s}, [x6], #16 | |||||
| ld1 {v22.4s}, [x6], #16 | |||||
| ld1 {v23.4s}, [x6], #16 | |||||
| ld1 {v13.4s}, [x12] | |||||
| ld1 {v12.4s}, [x11] | |||||
| ld1 {v11.4s}, [x13] | |||||
| b Apply | |||||
| PerTLoad: | |||||
| ld1 {v14.4s}, [x22], #16 | ld1 {v14.4s}, [x22], #16 | ||||
| dup v20.4s, v14.s[0] | dup v20.4s, v14.s[0] | ||||
| dup v21.4s, v14.s[1] | dup v21.4s, v14.s[1] | ||||
| dup v22.4s, v14.s[2] | dup v22.4s, v14.s[2] | ||||
| dup v23.4s, v14.s[3] | dup v23.4s, v14.s[3] | ||||
| ld1 {v14.s}[0], [x12] | |||||
| dup v13.4s, v14.s[0] | |||||
| ld1 {v14.s}[0], [x11] | |||||
| dup v12.4s, v14.s[0] | |||||
| ld1 {v14.s}[0], [x13] | |||||
| dup v11.4s, v14.s[0] | |||||
| b Apply | |||||
| Apply: | |||||
| // Subtract (Asums*Zb) | |||||
| sub v16.4s, v16.4s, v20.4s | sub v16.4s, v16.4s, v20.4s | ||||
| sub v17.4s, v17.4s, v21.4s | sub v17.4s, v17.4s, v21.4s | ||||
| sub v18.4s, v18.4s, v22.4s | sub v18.4s, v18.4s, v22.4s | ||||
| sub v19.4s, v19.4s, v23.4s | sub v19.4s, v19.4s, v23.4s | ||||
| // Apply left shift | // Apply left shift | ||||
| dup v13.4s, w12 | |||||
| sqshl v16.4s, v16.4s, v13.4s | sqshl v16.4s, v16.4s, v13.4s | ||||
| sqshl v17.4s, v17.4s, v13.4s | sqshl v17.4s, v17.4s, v13.4s | ||||
| sqshl v18.4s, v18.4s, v13.4s | sqshl v18.4s, v18.4s, v13.4s | ||||
| sqshl v19.4s, v19.4s, v13.4s | sqshl v19.4s, v19.4s, v13.4s | ||||
| // Apply the fixed-point part of the multiplier. | // Apply the fixed-point part of the multiplier. | ||||
| dup v12.4s, w11 | |||||
| sqrdmulh v16.4s, v16.4s, v12.4s | sqrdmulh v16.4s, v16.4s, v12.4s | ||||
| sqrdmulh v17.4s, v17.4s, v12.4s | sqrdmulh v17.4s, v17.4s, v12.4s | ||||
| sqrdmulh v18.4s, v18.4s, v12.4s | sqrdmulh v18.4s, v18.4s, v12.4s | ||||
| sqrdmulh v19.4s, v19.4s, v12.4s | sqrdmulh v19.4s, v19.4s, v12.4s | ||||
| // Apply right shift | // Apply right shift | ||||
| dup v11.4s, w13 | |||||
| and v20.16b, v11.16b, v16.16b | and v20.16b, v11.16b, v16.16b | ||||
| sshr v20.4s, v20.4s, #31 | sshr v20.4s, v20.4s, #31 | ||||
| sqadd v16.4s, v16.4s, v20.4s | sqadd v16.4s, v16.4s, v20.4s | ||||
| @@ -268,7 +292,7 @@ Write: | |||||
| beq WriteCol2 | beq WriteCol2 | ||||
| cmp w15, #1 | cmp w15, #1 | ||||
| beq WriteCol1 | beq WriteCol1 | ||||
| WriteCol4: | WriteCol4: | ||||
| st1 {v15.s}[0], [x2], x24 | st1 {v15.s}[0], [x2], x24 | ||||
| cmp w23, #1 | cmp w23, #1 | ||||
| @@ -349,7 +373,7 @@ WriteCol1: | |||||
| st1 {v15.b}[12], [x2], x24 | st1 {v15.b}[12], [x2], x24 | ||||
| b Endwrite | b Endwrite | ||||
| Endwrite: | |||||
| Endwrite: | |||||
| sub w16, w16, #4 // a row4 counter - 4 | sub w16, w16, #4 // a row4 counter - 4 | ||||
| sub w23, w23, #4 // a row counter - 4 | sub w23, w23, #4 // a row counter - 4 | ||||
| b L2 | b L2 | ||||
| @@ -361,15 +385,23 @@ End2: | |||||
| add x7, x7, #16 // bias ptr + stride | add x7, x7, #16 // bias ptr + stride | ||||
| add x25, x25, #4 // output + stride(4 * sizeof(int8)) | add x25, x25, #4 // output + stride(4 * sizeof(int8)) | ||||
| mov x2, x25 | mov x2, x25 | ||||
| cmp w27, #0 | |||||
| beq PerTEnd2 | |||||
| add x12, x12, #16 | |||||
| add x11, x11, #16 | |||||
| add x13, x13, #16 | |||||
| PerTEnd2: | |||||
| b L1 | b L1 | ||||
| End1: | End1: | ||||
| sub sp, sp, #192 | |||||
| sub sp, sp, #208 | |||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| ldp x19, x20, [sp], #16 | ldp x19, x20, [sp], #16 | ||||
| ldp x21, x22, [sp], #16 | ldp x21, x22, [sp], #16 | ||||
| ldp x23, x24, [sp], #16 | ldp x23, x24, [sp], #16 | ||||
| ldp x25, x26, [sp], #16 | ldp x25, x26, [sp], #16 | ||||
| ldp x27, x28, [sp], #16 | |||||
| ret | ret | ||||
| #endif | #endif | ||||
| @@ -8,7 +8,7 @@ | |||||
| //void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, | //void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, | ||||
| // const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, | // const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, | ||||
| // int multiplier, int left_shift, int right_shift, int row, int col, int stride); | |||||
| // int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, int peroc); | |||||
| // x0: a(left matrix ptr) | // x0: a(left matrix ptr) | ||||
| // x1: b(right matrix ptr) | // x1: b(right matrix ptr) | ||||
| @@ -21,31 +21,34 @@ | |||||
| // w8: act_min | // w8: act_min | ||||
| // w9: act_max | // w9: act_max | ||||
| // w10: out_zp | // w10: out_zp | ||||
| // w11: multiplier | |||||
| // w12: left_shift | |||||
| // w13: right_shift | |||||
| // x11: multiplier | |||||
| // x12: left_shift | |||||
| // x13: right_shift | |||||
| // w14: row | // w14: row | ||||
| // w15: col | // w15: col | ||||
| // w24: stride | // w24: stride | ||||
| // w27: filter_peroc | |||||
| MatmulInt8DpNeon64: | MatmulInt8DpNeon64: | ||||
| sub sp, sp, #192 | |||||
| sub sp, sp, #208 | |||||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| stp x19, x20, [sp], #16 | stp x19, x20, [sp], #16 | ||||
| stp x21, x22, [sp], #16 | stp x21, x22, [sp], #16 | ||||
| stp x23, x24, [sp], #16 | stp x23, x24, [sp], #16 | ||||
| stp x25, x26, [sp], #16 | stp x25, x26, [sp], #16 | ||||
| stp x27, x28, [sp], #16 | |||||
| ldr w8, [sp] | ldr w8, [sp] | ||||
| ldr w9, [sp, #8] | ldr w9, [sp, #8] | ||||
| ldr w10, [sp, #16] | ldr w10, [sp, #16] | ||||
| ldr w11, [sp, #24] | |||||
| ldr w12, [sp, #32] | |||||
| ldr w13, [sp, #40] | |||||
| ldr x11, [sp, #24] | |||||
| ldr x12, [sp, #32] | |||||
| ldr x13, [sp, #40] | |||||
| ldr w14, [sp, #48] | ldr w14, [sp, #48] | ||||
| ldr w15, [sp, #56] | ldr w15, [sp, #56] | ||||
| ldr w24, [sp, #64] | ldr w24, [sp, #64] | ||||
| ldr w27, [sp, #72] | |||||
| mov w17, #8 // sizeof(int8)*8 | mov w17, #8 // sizeof(int8)*8 | ||||
| mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 | mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 | ||||
| @@ -226,138 +229,171 @@ End3: | |||||
| add v29.4s, v29.4s, v14.4s | add v29.4s, v29.4s, v14.4s | ||||
| add v31.4s, v31.4s, v14.4s | add v31.4s, v31.4s, v14.4s | ||||
| cmp w27, #0 | |||||
| beq PerTSumLoad | |||||
| PerCSumLoad: | |||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64 | |||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64 | |||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 | |||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 | |||||
| b ApplySum | |||||
| PerTSumLoad: | |||||
| ld1 {v14.4s}, [x22], #16 | |||||
| ld1 {v15.4s}, [x22], #16 | |||||
| dup v0.4s, v14.s[0] | |||||
| dup v1.4s, v14.s[0] | |||||
| dup v2.4s, v14.s[1] | |||||
| dup v3.4s, v14.s[1] | |||||
| dup v4.4s, v14.s[2] | |||||
| dup v5.4s, v14.s[2] | |||||
| dup v6.4s, v14.s[3] | |||||
| dup v7.4s, v14.s[3] | |||||
| dup v8.4s, v15.s[0] | |||||
| dup v9.4s, v15.s[0] | |||||
| dup v10.4s, v15.s[1] | |||||
| dup v11.4s, v15.s[1] | |||||
| dup v12.4s, v15.s[2] | |||||
| dup v13.4s, v15.s[2] | |||||
| dup v14.4s, v15.s[3] | |||||
| dup v15.4s, v14.s[0] | |||||
| ApplySum: | |||||
| // Subtract (Asums*Zb) | // Subtract (Asums*Zb) | ||||
| ld1 {v13.4s}, [x22], #16 | |||||
| ld1 {v12.4s}, [x22], #16 | |||||
| dup v0.4s, v13.s[0] | |||||
| dup v1.4s, v13.s[1] | |||||
| dup v2.4s, v13.s[2] | |||||
| dup v3.4s, v13.s[3] | |||||
| dup v4.4s, v12.s[0] | |||||
| dup v5.4s, v12.s[1] | |||||
| dup v6.4s, v12.s[2] | |||||
| dup v7.4s, v12.s[3] | |||||
| sub v16.4s, v16.4s, v0.4s | sub v16.4s, v16.4s, v0.4s | ||||
| sub v17.4s, v17.4s, v0.4s | |||||
| sub v18.4s, v18.4s, v1.4s | |||||
| sub v19.4s, v19.4s, v1.4s | |||||
| sub v20.4s, v20.4s, v2.4s | |||||
| sub v21.4s, v21.4s, v2.4s | |||||
| sub v22.4s, v22.4s, v3.4s | |||||
| sub v23.4s, v23.4s, v3.4s | |||||
| sub v24.4s, v24.4s, v4.4s | |||||
| sub v25.4s, v25.4s, v4.4s | |||||
| sub v26.4s, v26.4s, v5.4s | |||||
| sub v27.4s, v27.4s, v5.4s | |||||
| sub v28.4s, v28.4s, v6.4s | |||||
| sub v29.4s, v29.4s, v6.4s | |||||
| sub v30.4s, v30.4s, v7.4s | |||||
| sub v31.4s, v31.4s, v7.4s | |||||
| sub v17.4s, v17.4s, v1.4s | |||||
| sub v18.4s, v18.4s, v2.4s | |||||
| sub v19.4s, v19.4s, v3.4s | |||||
| sub v20.4s, v20.4s, v4.4s | |||||
| sub v21.4s, v21.4s, v5.4s | |||||
| sub v22.4s, v22.4s, v6.4s | |||||
| sub v23.4s, v23.4s, v7.4s | |||||
| sub v24.4s, v24.4s, v8.4s | |||||
| sub v25.4s, v25.4s, v9.4s | |||||
| sub v26.4s, v26.4s, v10.4s | |||||
| sub v27.4s, v27.4s, v11.4s | |||||
| sub v28.4s, v28.4s, v12.4s | |||||
| sub v29.4s, v29.4s, v13.4s | |||||
| sub v30.4s, v30.4s, v14.4s | |||||
| sub v31.4s, v31.4s, v15.4s | |||||
| cmp w27, #0 | |||||
| beq PerTRoundLoad | |||||
| PerCRoundLoad: | |||||
| ld1 {v8.4s, v9.4s}, [x12] | |||||
| ld1 {v10.4s, v11.4s}, [x11] | |||||
| ld1 {v12.4s, v13.4s}, [x13] | |||||
| b ApplyRound | |||||
| PerTRoundLoad: | |||||
| ld1 {v14.s}[0], [x12] | |||||
| dup v8.4s, v14.s[0] | |||||
| dup v9.4s, v14.s[0] | |||||
| ld1 {v14.s}[0], [x11] | |||||
| dup v10.4s, v14.s[0] | |||||
| dup v11.4s, v14.s[0] | |||||
| ld1 {v14.s}[0], [x13] | |||||
| dup v12.4s, v14.s[0] | |||||
| dup v13.4s, v14.s[0] | |||||
| ApplyRound: | |||||
| // Apply left shift | // Apply left shift | ||||
| dup v11.4s, w12 | |||||
| sqshl v16.4s, v16.4s, v11.4s | |||||
| sqshl v17.4s, v17.4s, v11.4s | |||||
| sqshl v18.4s, v18.4s, v11.4s | |||||
| sqshl v19.4s, v19.4s, v11.4s | |||||
| sqshl v20.4s, v20.4s, v11.4s | |||||
| sqshl v21.4s, v21.4s, v11.4s | |||||
| sqshl v22.4s, v22.4s, v11.4s | |||||
| sqshl v23.4s, v23.4s, v11.4s | |||||
| sqshl v24.4s, v24.4s, v11.4s | |||||
| sqshl v25.4s, v25.4s, v11.4s | |||||
| sqshl v26.4s, v26.4s, v11.4s | |||||
| sqshl v27.4s, v27.4s, v11.4s | |||||
| sqshl v28.4s, v28.4s, v11.4s | |||||
| sqshl v29.4s, v29.4s, v11.4s | |||||
| sqshl v30.4s, v30.4s, v11.4s | |||||
| sqshl v31.4s, v31.4s, v11.4s | |||||
| sqshl v16.4s, v16.4s, v8.4s | |||||
| sqshl v17.4s, v17.4s, v9.4s | |||||
| sqshl v18.4s, v18.4s, v8.4s | |||||
| sqshl v19.4s, v19.4s, v9.4s | |||||
| sqshl v20.4s, v20.4s, v8.4s | |||||
| sqshl v21.4s, v21.4s, v9.4s | |||||
| sqshl v22.4s, v22.4s, v8.4s | |||||
| sqshl v23.4s, v23.4s, v9.4s | |||||
| sqshl v24.4s, v24.4s, v8.4s | |||||
| sqshl v25.4s, v25.4s, v9.4s | |||||
| sqshl v26.4s, v26.4s, v8.4s | |||||
| sqshl v27.4s, v27.4s, v9.4s | |||||
| sqshl v28.4s, v28.4s, v8.4s | |||||
| sqshl v29.4s, v29.4s, v9.4s | |||||
| sqshl v30.4s, v30.4s, v8.4s | |||||
| sqshl v31.4s, v31.4s, v9.4s | |||||
| // Apply the fixed-point part of the multiplier. | // Apply the fixed-point part of the multiplier. | ||||
| dup v10.4s, w11 | |||||
| sqrdmulh v16.4s, v16.4s, v10.4s | sqrdmulh v16.4s, v16.4s, v10.4s | ||||
| sqrdmulh v17.4s, v17.4s, v10.4s | |||||
| sqrdmulh v17.4s, v17.4s, v11.4s | |||||
| sqrdmulh v18.4s, v18.4s, v10.4s | sqrdmulh v18.4s, v18.4s, v10.4s | ||||
| sqrdmulh v19.4s, v19.4s, v10.4s | |||||
| sqrdmulh v19.4s, v19.4s, v11.4s | |||||
| sqrdmulh v20.4s, v20.4s, v10.4s | sqrdmulh v20.4s, v20.4s, v10.4s | ||||
| sqrdmulh v21.4s, v21.4s, v10.4s | |||||
| sqrdmulh v21.4s, v21.4s, v11.4s | |||||
| sqrdmulh v22.4s, v22.4s, v10.4s | sqrdmulh v22.4s, v22.4s, v10.4s | ||||
| sqrdmulh v23.4s, v23.4s, v10.4s | |||||
| sqrdmulh v23.4s, v23.4s, v11.4s | |||||
| sqrdmulh v24.4s, v24.4s, v10.4s | sqrdmulh v24.4s, v24.4s, v10.4s | ||||
| sqrdmulh v25.4s, v25.4s, v10.4s | |||||
| sqrdmulh v25.4s, v25.4s, v11.4s | |||||
| sqrdmulh v26.4s, v26.4s, v10.4s | sqrdmulh v26.4s, v26.4s, v10.4s | ||||
| sqrdmulh v27.4s, v27.4s, v10.4s | |||||
| sqrdmulh v27.4s, v27.4s, v11.4s | |||||
| sqrdmulh v28.4s, v28.4s, v10.4s | sqrdmulh v28.4s, v28.4s, v10.4s | ||||
| sqrdmulh v29.4s, v29.4s, v10.4s | |||||
| sqrdmulh v29.4s, v29.4s, v11.4s | |||||
| sqrdmulh v30.4s, v30.4s, v10.4s | sqrdmulh v30.4s, v30.4s, v10.4s | ||||
| sqrdmulh v31.4s, v31.4s, v10.4s | |||||
| sqrdmulh v31.4s, v31.4s, v11.4s | |||||
| // Apply right shift | // Apply right shift | ||||
| dup v9.4s, w13 | |||||
| and v0.16b, v9.16b, v16.16b | |||||
| and v0.16b, v12.16b, v16.16b | |||||
| sshr v0.4s, v0.4s, #31 | sshr v0.4s, v0.4s, #31 | ||||
| sqadd v16.4s, v16.4s, v0.4s | sqadd v16.4s, v16.4s, v0.4s | ||||
| srshl v16.4s, v16.4s, v9.4s | |||||
| and v1.16b, v9.16b, v17.16b | |||||
| srshl v16.4s, v16.4s, v12.4s | |||||
| and v1.16b, v13.16b, v17.16b | |||||
| sshr v1.4s, v1.4s, #31 | sshr v1.4s, v1.4s, #31 | ||||
| sqadd v17.4s, v17.4s, v1.4s | sqadd v17.4s, v17.4s, v1.4s | ||||
| srshl v17.4s, v17.4s, v9.4s | |||||
| and v2.16b, v9.16b, v18.16b | |||||
| srshl v17.4s, v17.4s, v13.4s | |||||
| and v2.16b, v12.16b, v18.16b | |||||
| sshr v2.4s, v2.4s, #31 | sshr v2.4s, v2.4s, #31 | ||||
| sqadd v18.4s, v18.4s, v2.4s | sqadd v18.4s, v18.4s, v2.4s | ||||
| srshl v18.4s, v18.4s, v9.4s | |||||
| and v3.16b, v9.16b, v19.16b | |||||
| srshl v18.4s, v18.4s, v12.4s | |||||
| and v3.16b, v13.16b, v19.16b | |||||
| sshr v3.4s, v3.4s, #31 | sshr v3.4s, v3.4s, #31 | ||||
| sqadd v19.4s, v19.4s, v3.4s | sqadd v19.4s, v19.4s, v3.4s | ||||
| srshl v19.4s, v19.4s, v9.4s | |||||
| and v0.16b, v9.16b, v20.16b | |||||
| srshl v19.4s, v19.4s, v13.4s | |||||
| and v0.16b, v12.16b, v20.16b | |||||
| sshr v0.4s, v0.4s, #31 | sshr v0.4s, v0.4s, #31 | ||||
| sqadd v20.4s, v20.4s, v0.4s | sqadd v20.4s, v20.4s, v0.4s | ||||
| srshl v20.4s, v20.4s, v9.4s | |||||
| and v1.16b, v9.16b, v21.16b | |||||
| srshl v20.4s, v20.4s, v12.4s | |||||
| and v1.16b, v13.16b, v21.16b | |||||
| sshr v1.4s, v1.4s, #31 | sshr v1.4s, v1.4s, #31 | ||||
| sqadd v21.4s, v21.4s, v1.4s | sqadd v21.4s, v21.4s, v1.4s | ||||
| srshl v21.4s, v21.4s, v9.4s | |||||
| and v2.16b, v9.16b, v22.16b | |||||
| srshl v21.4s, v21.4s, v13.4s | |||||
| and v2.16b, v12.16b, v22.16b | |||||
| sshr v2.4s, v2.4s, #31 | sshr v2.4s, v2.4s, #31 | ||||
| sqadd v22.4s, v22.4s, v2.4s | sqadd v22.4s, v22.4s, v2.4s | ||||
| srshl v22.4s, v22.4s, v9.4s | |||||
| and v3.16b, v9.16b, v23.16b | |||||
| srshl v22.4s, v22.4s, v12.4s | |||||
| and v3.16b, v13.16b, v23.16b | |||||
| sshr v3.4s, v3.4s, #31 | sshr v3.4s, v3.4s, #31 | ||||
| sqadd v23.4s, v23.4s, v3.4s | sqadd v23.4s, v23.4s, v3.4s | ||||
| srshl v23.4s, v23.4s, v9.4s | |||||
| and v0.16b, v9.16b, v24.16b | |||||
| srshl v23.4s, v23.4s, v13.4s | |||||
| and v0.16b, v12.16b, v24.16b | |||||
| sshr v0.4s, v0.4s, #31 | sshr v0.4s, v0.4s, #31 | ||||
| sqadd v24.4s, v24.4s, v0.4s | sqadd v24.4s, v24.4s, v0.4s | ||||
| srshl v24.4s, v24.4s, v9.4s | |||||
| and v1.16b, v9.16b, v25.16b | |||||
| srshl v24.4s, v24.4s, v12.4s | |||||
| and v1.16b, v13.16b, v25.16b | |||||
| sshr v1.4s, v1.4s, #31 | sshr v1.4s, v1.4s, #31 | ||||
| sqadd v25.4s, v25.4s, v1.4s | sqadd v25.4s, v25.4s, v1.4s | ||||
| srshl v25.4s, v25.4s, v9.4s | |||||
| and v2.16b, v9.16b, v26.16b | |||||
| srshl v25.4s, v25.4s, v13.4s | |||||
| and v2.16b, v12.16b, v26.16b | |||||
| sshr v2.4s, v2.4s, #31 | sshr v2.4s, v2.4s, #31 | ||||
| sqadd v26.4s, v26.4s, v2.4s | sqadd v26.4s, v26.4s, v2.4s | ||||
| srshl v26.4s, v26.4s, v9.4s | |||||
| and v3.16b, v9.16b, v27.16b | |||||
| srshl v26.4s, v26.4s, v12.4s | |||||
| and v3.16b, v13.16b, v27.16b | |||||
| sshr v3.4s, v3.4s, #31 | sshr v3.4s, v3.4s, #31 | ||||
| sqadd v27.4s, v27.4s, v3.4s | sqadd v27.4s, v27.4s, v3.4s | ||||
| srshl v27.4s, v27.4s, v9.4s | |||||
| and v0.16b, v9.16b, v28.16b | |||||
| srshl v27.4s, v27.4s, v13.4s | |||||
| and v0.16b, v12.16b, v28.16b | |||||
| sshr v0.4s, v0.4s, #31 | sshr v0.4s, v0.4s, #31 | ||||
| sqadd v28.4s, v28.4s, v0.4s | sqadd v28.4s, v28.4s, v0.4s | ||||
| srshl v28.4s, v28.4s, v9.4s | |||||
| and v1.16b, v9.16b, v29.16b | |||||
| srshl v28.4s, v28.4s, v12.4s | |||||
| and v1.16b, v13.16b, v29.16b | |||||
| sshr v1.4s, v1.4s, #31 | sshr v1.4s, v1.4s, #31 | ||||
| sqadd v29.4s, v29.4s, v1.4s | sqadd v29.4s, v29.4s, v1.4s | ||||
| srshl v29.4s, v29.4s, v9.4s | |||||
| and v2.16b, v9.16b, v30.16b | |||||
| srshl v29.4s, v29.4s, v13.4s | |||||
| and v2.16b, v12.16b, v30.16b | |||||
| sshr v2.4s, v2.4s, #31 | sshr v2.4s, v2.4s, #31 | ||||
| sqadd v30.4s, v30.4s, v2.4s | sqadd v30.4s, v30.4s, v2.4s | ||||
| srshl v30.4s, v30.4s, v9.4s | |||||
| and v3.16b, v9.16b, v31.16b | |||||
| srshl v30.4s, v30.4s, v12.4s | |||||
| and v3.16b, v13.16b, v31.16b | |||||
| sshr v3.4s, v3.4s, #31 | sshr v3.4s, v3.4s, #31 | ||||
| sqadd v31.4s, v31.4s, v3.4s | sqadd v31.4s, v31.4s, v3.4s | ||||
| srshl v31.4s, v31.4s, v9.4s | |||||
| srshl v31.4s, v31.4s, v13.4s | |||||
| // Add the destination zero point | // Add the destination zero point | ||||
| dup v8.4s, w10 | dup v8.4s, w10 | ||||
| @@ -793,15 +829,23 @@ End2: | |||||
| add x7, x7, #32 // bias ptr + stride | add x7, x7, #32 // bias ptr + stride | ||||
| add x25, x25, #8 // output + stride(8 * sizeof(int8)) | add x25, x25, #8 // output + stride(8 * sizeof(int8)) | ||||
| mov x2, x25 | mov x2, x25 | ||||
| cmp w27, #0 | |||||
| beq PerTEnd2 | |||||
| add x12, x12, #32 | |||||
| add x11, x11, #32 | |||||
| add x13, x13, #32 | |||||
| PerTEnd2: | |||||
| b L1 | b L1 | ||||
| End1: | End1: | ||||
| sub sp, sp, #192 | |||||
| sub sp, sp, #208 | |||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| ldp x19, x20, [sp], #16 | ldp x19, x20, [sp], #16 | ||||
| ldp x21, x22, [sp], #16 | ldp x21, x22, [sp], #16 | ||||
| ldp x23, x24, [sp], #16 | ldp x23, x24, [sp], #16 | ||||
| ldp x25, x26, [sp], #16 | ldp x25, x26, [sp], #16 | ||||
| ldp x27, x28, [sp], #16 | |||||
| ret | ret | ||||
| #endif | #endif | ||||
| @@ -385,12 +385,302 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t * | |||||
| int8_t *pack_r = packed_input; | int8_t *pack_r = packed_input; | ||||
| int32_t *input_sum_r = input_sum; | int32_t *input_sum_r = input_sum; | ||||
| /* per layer */ | |||||
| for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { | for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { | ||||
| const int8_t *src_ic = src_r; | const int8_t *src_ic = src_r; | ||||
| int8_t *pack_ic = pack_r; | int8_t *pack_ic = pack_r; | ||||
| int32_t *input_sum_oc = input_sum_r; | int32_t *input_sum_oc = input_sum_r; | ||||
| #ifdef ENABLE_ARM64 | |||||
| size_t src_stride = input_channel; | |||||
| size_t ic_4res = input_channel - ic_4div; | |||||
| size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; | |||||
| asm volatile( | |||||
| "dup v16.4s, wzr \n" | |||||
| "dup v17.4s, wzr \n" | |||||
| "mov x10, %[src_ic] \n" | |||||
| "mov x11, %[pack_ic] \n" | |||||
| "mov x0, #0 \n" | |||||
| "1: \n" | |||||
| "cmp x0, %[ic_4div] \n" | |||||
| "add x0, x0, #4\n" | |||||
| "mov x12, x10 \n" | |||||
| "add x10, x10, #4\n" | |||||
| "blt 2f \n" | |||||
| "cmp %[ic_4res], #0\n" | |||||
| "beq 6f \n" | |||||
| "cmp %[ic_4res], #1\n" | |||||
| "beq 3f \n" | |||||
| "cmp %[ic_4res], #2\n" | |||||
| "beq 4f \n" | |||||
| "cmp %[ic_4res], #3\n" | |||||
| "beq 5f \n" | |||||
| "2: \n" | |||||
| "ld1 {v0.s}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.s}[1], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.s}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.s}[3], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.s}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.s}[1], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.s}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.s}[3], [x12], %[src_stride]\n" | |||||
| "st1 {v0.16b}, [x11], #16\n" | |||||
| "st1 {v1.16b}, [x11], #16\n" | |||||
| "saddlp v4.8h, v0.16b \n" | |||||
| "saddlp v5.8h, v1.16b \n" | |||||
| "saddlp v0.4s, v4.8h \n" | |||||
| "saddlp v1.4s, v5.8h \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 1b \n" | |||||
| "3: \n" /* col res 1 */ | |||||
| "dup v0.4s, wzr \n" | |||||
| "dup v1.4s, wzr \n" | |||||
| "ld1 {v0.b}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[8], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[12], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[8], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[12], [x12], %[src_stride]\n" | |||||
| "st1 {v0.16b}, [x11], #16\n" | |||||
| "st1 {v1.16b}, [x11], #16\n" | |||||
| "saddlp v4.8h, v0.16b \n" | |||||
| "saddlp v5.8h, v1.16b \n" | |||||
| "saddlp v0.4s, v4.8h \n" | |||||
| "saddlp v1.4s, v5.8h \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | |||||
| "4: \n" /* col res 2 */ | |||||
| "dup v0.4s, wzr \n" | |||||
| "dup v1.4s, wzr \n" | |||||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||||
| "st1 {v0.16b}, [x11], #16\n" | |||||
| "st1 {v1.16b}, [x11], #16\n" | |||||
| "saddlp v4.8h, v0.16b \n" | |||||
| "saddlp v5.8h, v1.16b \n" | |||||
| "saddlp v0.4s, v4.8h \n" | |||||
| "saddlp v1.4s, v5.8h \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | |||||
| "5: \n" /* col res 3 */ | |||||
| "dup v0.4s, wzr \n" | |||||
| "dup v1.4s, wzr \n" | |||||
| "add x13, x12, #2 \n" | |||||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[2], [x13], %[src_stride]\n" | |||||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[6], [x13], %[src_stride]\n" | |||||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[10], [x13], %[src_stride]\n" | |||||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||||
| "ld1 {v0.b}[14], [x13], %[src_stride]\n" | |||||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[2], [x13], %[src_stride]\n" | |||||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[6], [x13], %[src_stride]\n" | |||||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[10], [x13], %[src_stride]\n" | |||||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||||
| "ld1 {v1.b}[14], [x13], %[src_stride]\n" | |||||
| "st1 {v0.16b}, [x11], #16\n" | |||||
| "st1 {v1.16b}, [x11], #16\n" | |||||
| "saddlp v4.8h, v0.16b \n" | |||||
| "saddlp v5.8h, v1.16b \n" | |||||
| "saddlp v0.4s, v4.8h \n" | |||||
| "saddlp v1.4s, v5.8h \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | |||||
| "6: \n" | |||||
| "dup v0.4s, v16.s[0] \n" | |||||
| "dup v1.4s, v16.s[1] \n" | |||||
| "dup v2.4s, v16.s[2] \n" | |||||
| "dup v3.4s, v16.s[3] \n" | |||||
| "dup v4.4s, v17.s[0] \n" | |||||
| "dup v5.4s, v17.s[1] \n" | |||||
| "dup v6.4s, v17.s[2] \n" | |||||
| "dup v7.4s, v17.s[3] \n" | |||||
| "mov x4, #0 \n" | |||||
| "mov x10, %[filter_zp] \n" | |||||
| "mov x11, %[input_sum_oc] \n" | |||||
| "7: \n" | |||||
| "cmp x4, %[oc_8div] \n" | |||||
| "beq 8f \n" | |||||
| "add x4, x4, #8\n" | |||||
| "ld1 {v16.4s}, [x10], #16\n" | |||||
| "ld1 {v17.4s}, [x10], #16\n" | |||||
| "mul v18.4s, v16.4s, v0.4s \n" | |||||
| "mul v19.4s, v17.4s, v0.4s \n" | |||||
| "st1 {v18.4s}, [x11], #16 \n" | |||||
| "st1 {v19.4s}, [x11], #16 \n" | |||||
| "mul v20.4s, v16.4s, v1.4s \n" | |||||
| "mul v21.4s, v17.4s, v1.4s \n" | |||||
| "st1 {v20.4s}, [x11], #16 \n" | |||||
| "st1 {v21.4s}, [x11], #16 \n" | |||||
| "mul v22.4s, v16.4s, v2.4s \n" | |||||
| "mul v23.4s, v17.4s, v2.4s \n" | |||||
| "st1 {v22.4s}, [x11], #16 \n" | |||||
| "st1 {v23.4s}, [x11], #16 \n" | |||||
| "mul v24.4s, v16.4s, v3.4s \n" | |||||
| "mul v25.4s, v17.4s, v3.4s \n" | |||||
| "st1 {v24.4s}, [x11], #16 \n" | |||||
| "st1 {v25.4s}, [x11], #16 \n" | |||||
| "mul v18.4s, v16.4s, v4.4s \n" | |||||
| "mul v19.4s, v17.4s, v4.4s \n" | |||||
| "st1 {v18.4s}, [x11], #16 \n" | |||||
| "st1 {v19.4s}, [x11], #16 \n" | |||||
| "mul v20.4s, v16.4s, v5.4s \n" | |||||
| "mul v21.4s, v17.4s, v5.4s \n" | |||||
| "st1 {v20.4s}, [x11], #16 \n" | |||||
| "st1 {v21.4s}, [x11], #16 \n" | |||||
| "mul v22.4s, v16.4s, v6.4s \n" | |||||
| "mul v23.4s, v17.4s, v6.4s \n" | |||||
| "st1 {v22.4s}, [x11], #16 \n" | |||||
| "st1 {v23.4s}, [x11], #16 \n" | |||||
| "mul v24.4s, v16.4s, v7.4s \n" | |||||
| "mul v25.4s, v17.4s, v7.4s \n" | |||||
| "st1 {v24.4s}, [x11], #16 \n" | |||||
| "st1 {v25.4s}, [x11], #16 \n" | |||||
| "add x11, x11, %[input_sum_stride] \n" | |||||
| "b 7b \n" | |||||
| "8: \n" | |||||
| "cmp %[oc_8res], #0\n" | |||||
| "beq 17f \n" | |||||
| "dup v16.4s, wzr \n" | |||||
| "dup v17.4s, wzr \n" | |||||
| "cmp %[oc_8res], #1\n" | |||||
| "beq 9f \n" | |||||
| "cmp %[oc_8res], #2\n" | |||||
| "beq 10f \n" | |||||
| "cmp %[oc_8res], #3\n" | |||||
| "beq 11f \n" | |||||
| "cmp %[oc_8res], #4\n" | |||||
| "beq 12f \n" | |||||
| "cmp %[oc_8res], #5\n" | |||||
| "beq 13f \n" | |||||
| "cmp %[oc_8res], #6\n" | |||||
| "beq 14f \n" | |||||
| "cmp %[oc_8res], #7\n" | |||||
| "beq 15f \n" | |||||
| /* Residual output channels (oc_8res = 1..7): load the remaining int32 filter zero-points into v16/v17. | |||||
| * Both registers were zeroed before the dispatch, so lanes that are not loaded stay 0. | |||||
| * A pair of int32 values is 8 bytes, so pairs are loaded through a 64-bit .d lane; a 16-bit .h lane | |||||
| * would truncate the first value and drop the second. */ | |||||
| "9: \n" /* oc res 1: zp[0] */ | |||||
| "ld1 {v16.s}[0], [x10] \n" | |||||
| "b 16f \n" | |||||
| "10: \n" /* oc res 2: zp[0..1] */ | |||||
| "ld1 {v16.d}[0], [x10] \n" | |||||
| "b 16f \n" | |||||
| "11: \n" /* oc res 3: zp[0..1], then zp[2] */ | |||||
| "ld1 {v16.d}[0], [x10] \n" | |||||
| "add x10, x10, #8 \n" | |||||
| "ld1 {v16.s}[2], [x10] \n" | |||||
| "b 16f \n" | |||||
| "12: \n" /* oc res 4: zp[0..3] */ | |||||
| "ld1 {v16.4s}, [x10] \n" | |||||
| "b 16f \n" | |||||
| "13: \n" /* oc res 5: zp[0..3], then zp[4] */ | |||||
| "ld1 {v16.4s}, [x10], #16\n" | |||||
| "ld1 {v17.s}[0], [x10] \n" | |||||
| "b 16f \n" | |||||
| "14: \n" /* oc res 6: zp[0..3], then zp[4..5] */ | |||||
| "ld1 {v16.4s}, [x10], #16\n" | |||||
| "ld1 {v17.d}[0], [x10] \n" | |||||
| "b 16f \n" | |||||
| "15: \n" /* oc res 7: zp[0..3], zp[4..5], then zp[6] */ | |||||
| "ld1 {v16.4s}, [x10], #16\n" | |||||
| "ld1 {v17.d}[0], [x10] \n" | |||||
| "add x10, x10, #8 \n" | |||||
| "ld1 {v17.s}[2], [x10] \n" | |||||
| "b 16f \n" | |||||
| "16: \n" | |||||
| "mul v18.4s, v16.4s, v0.4s \n" | |||||
| "mul v19.4s, v17.4s, v0.4s \n" | |||||
| "mul v20.4s, v16.4s, v1.4s \n" | |||||
| "mul v21.4s, v17.4s, v1.4s \n" | |||||
| "mul v22.4s, v16.4s, v2.4s \n" | |||||
| "mul v23.4s, v17.4s, v2.4s \n" | |||||
| "mul v24.4s, v16.4s, v3.4s \n" | |||||
| "mul v25.4s, v17.4s, v3.4s \n" | |||||
| "st1 {v18.4s}, [x11], #16 \n" | |||||
| "st1 {v19.4s}, [x11], #16 \n" | |||||
| "st1 {v20.4s}, [x11], #16 \n" | |||||
| "st1 {v21.4s}, [x11], #16 \n" | |||||
| "st1 {v22.4s}, [x11], #16 \n" | |||||
| "st1 {v23.4s}, [x11], #16 \n" | |||||
| "st1 {v24.4s}, [x11], #16 \n" | |||||
| "st1 {v25.4s}, [x11], #16 \n" | |||||
| "mul v18.4s, v16.4s, v4.4s \n" | |||||
| "mul v19.4s, v17.4s, v4.4s \n" | |||||
| "mul v20.4s, v16.4s, v5.4s \n" | |||||
| "mul v21.4s, v17.4s, v5.4s \n" | |||||
| "mul v22.4s, v16.4s, v6.4s \n" | |||||
| "mul v23.4s, v17.4s, v6.4s \n" | |||||
| "mul v24.4s, v16.4s, v7.4s \n" | |||||
| "mul v25.4s, v17.4s, v7.4s \n" | |||||
| "st1 {v18.4s}, [x11], #16 \n" | |||||
| "st1 {v19.4s}, [x11], #16 \n" | |||||
| "st1 {v20.4s}, [x11], #16 \n" | |||||
| "st1 {v21.4s}, [x11], #16 \n" | |||||
| "st1 {v22.4s}, [x11], #16 \n" | |||||
| "st1 {v23.4s}, [x11], #16 \n" | |||||
| "st1 {v24.4s}, [x11], #16 \n" | |||||
| "st1 {v25.4s}, [x11], #16 \n" | |||||
| "17: \n" | |||||
| /* No output operands: results are written through the pack_ic and input_sum_oc pointers. */ | |||||
| : | |||||
| : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp), | |||||
| [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride), | |||||
| [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res) | |||||
| /* "cc": the asm executes cmp, which changes the condition flags. | |||||
| * "memory": the asm stores to the packed-input and input-sum buffers, which the compiler cannot see. */ | |||||
| : "cc", "memory", "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", | |||||
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); | |||||
| #else | |||||
| int32_t tmp_sum_value[8] = {0}; | int32_t tmp_sum_value[8] = {0}; | ||||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | for (int ici = 0; ici < ic_4div; ici += C4NUM) { | ||||
| for (int i = 0; i < C8NUM; i++) { | for (int i = 0; i < C8NUM; i++) { | ||||
| @@ -440,7 +730,7 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t * | |||||
| } | } | ||||
| } | } | ||||
| } /* oc8 res done */ | } /* oc8 res done */ | ||||
| #endif | |||||
| src_r += input_channel * C8NUM; | src_r += input_channel * C8NUM; | ||||
| pack_r += ic4 * C8NUM; | pack_r += ic4 * C8NUM; | ||||
| input_sum_r += C8NUM * C8NUM; | input_sum_r += C8NUM * C8NUM; | ||||
| @@ -520,9 +810,9 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| size_t src_stride = input_channel; | size_t src_stride = input_channel; | ||||
| size_t ic_4res = input_channel - ic_4div; | size_t ic_4res = input_channel - ic_4div; | ||||
| asm volatile( | asm volatile( | ||||
| "dup v10.4s, wzr \n" | |||||
| "dup v11.4s, wzr \n" | |||||
| "mov x20, %[input_sum_r] \n" | |||||
| "dup v16.4s, wzr \n" | |||||
| "dup v17.4s, wzr \n" | |||||
| "mov x14, %[input_sum_r] \n" | |||||
| "dup v20.4s, %w[filter_zp] \n" | "dup v20.4s, %w[filter_zp] \n" | ||||
| "mov x10, %[src_ic] \n" | "mov x10, %[src_ic] \n" | ||||
| @@ -563,8 +853,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| "saddlp v0.4s, v4.8h \n" | "saddlp v0.4s, v4.8h \n" | ||||
| "saddlp v1.4s, v5.8h \n" | "saddlp v1.4s, v5.8h \n" | ||||
| "add v10.4s, v10.4s, v0.4s \n" | |||||
| "add v11.4s, v11.4s, v1.4s \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 1b \n" | "b 1b \n" | ||||
| "3: \n" /* col res 1 */ | "3: \n" /* col res 1 */ | ||||
| @@ -586,8 +876,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| "saddlp v5.8h, v1.16b \n" | "saddlp v5.8h, v1.16b \n" | ||||
| "saddlp v0.4s, v4.8h \n" | "saddlp v0.4s, v4.8h \n" | ||||
| "saddlp v1.4s, v5.8h \n" | "saddlp v1.4s, v5.8h \n" | ||||
| "add v10.4s, v10.4s, v0.4s \n" | |||||
| "add v11.4s, v11.4s, v1.4s \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | "b 6f \n" | ||||
| "4: \n" /* col res 2 */ | "4: \n" /* col res 2 */ | ||||
| @@ -609,8 +899,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| "saddlp v5.8h, v1.16b \n" | "saddlp v5.8h, v1.16b \n" | ||||
| "saddlp v0.4s, v4.8h \n" | "saddlp v0.4s, v4.8h \n" | ||||
| "saddlp v1.4s, v5.8h \n" | "saddlp v1.4s, v5.8h \n" | ||||
| "add v10.4s, v10.4s, v0.4s \n" | |||||
| "add v11.4s, v11.4s, v1.4s \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | "b 6f \n" | ||||
| "5: \n" /* col res 3 */ | "5: \n" /* col res 3 */ | ||||
| @@ -641,21 +931,21 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| "saddlp v5.8h, v1.16b \n" | "saddlp v5.8h, v1.16b \n" | ||||
| "saddlp v0.4s, v4.8h \n" | "saddlp v0.4s, v4.8h \n" | ||||
| "saddlp v1.4s, v5.8h \n" | "saddlp v1.4s, v5.8h \n" | ||||
| "add v10.4s, v10.4s, v0.4s \n" | |||||
| "add v11.4s, v11.4s, v1.4s \n" | |||||
| "add v16.4s, v16.4s, v0.4s \n" | |||||
| "add v17.4s, v17.4s, v1.4s \n" | |||||
| "b 6f \n" | "b 6f \n" | ||||
| "6: \n" | "6: \n" | ||||
| "mul v10.4s, v10.4s, v20.4s \n" | |||||
| "mul v11.4s, v11.4s, v20.4s \n" | |||||
| "mul v16.4s, v16.4s, v20.4s \n" | |||||
| "mul v17.4s, v17.4s, v20.4s \n" | |||||
| "st1 {v10.4s}, [x20], #16 \n" | |||||
| "st1 {v11.4s}, [x20], #16 \n" | |||||
| "st1 {v16.4s}, [x14], #16 \n" | |||||
| "st1 {v17.4s}, [x14], #16 \n" | |||||
| : | : | ||||
| : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), | : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), | ||||
| [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) | [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) | ||||
| : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", | |||||
| : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
| "v20"); | "v20"); | ||||
| #else | #else | ||||
| int32_t tmp_sum_value[8] = {0}; | int32_t tmp_sum_value[8] = {0}; | ||||
| @@ -728,10 +1018,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i | |||||
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | ||||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | ||||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { | int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { | ||||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||||
| matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, | matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, | ||||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | ||||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -756,24 +1046,17 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i | |||||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | ||||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | ||||
| int32_t *multiplier, ConvParameter *conv_param) { | int32_t *multiplier, ConvParameter *conv_param) { | ||||
| if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) { | |||||
| return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, | |||||
| bias, left_shift, right_shift, multiplier, | |||||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||||
| } | |||||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, | MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, | ||||
| bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | ||||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||||
| conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], | |||||
| conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_); | |||||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col, | |||||
| conv_param->output_channel_, is_per_oc); | |||||
| #else | #else | ||||
| MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, | MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, | ||||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | ||||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | ||||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||||
| is_per_oc); | |||||
| #endif | #endif | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -269,7 +269,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, | |||||
| void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | ||||
| bool per_channel) { | |||||
| size_t per_channel) { | |||||
| /* row8x4-major * row4x8-major => (int8)row-major */ | /* row8x4-major * row4x8-major => (int8)row-major */ | ||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| for (int c = 0; c < col; c++) { | for (int c = 0; c < col; c++) { | ||||
| @@ -39,7 +39,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); | |||||
| void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | ||||
| bool per_channel); | |||||
| size_t per_channel); | |||||
| void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | ||||
| void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | ||||
| void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); | void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); | ||||
| @@ -59,8 +59,8 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, | void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, | ||||
| const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift, | |||||
| int right_shift, int row, int col, int stride); | |||||
| const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, | |||||
| int32_t *right_shift, int row, int col, int stride, int filter_peroc); | |||||
| void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, | void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, | ||||
| const int *input_sum, const int *bias); | const int *input_sum, const int *bias); | ||||
| @@ -25,7 +25,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, i | |||||
| typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | ||||
| int32_t maxi, bool per_channel); | |||||
| int32_t maxi, size_t per_channel); | |||||
| typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col); | typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col); | ||||
| @@ -30,8 +30,9 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_ | |||||
| extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, | extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, | ||||
| const int *input_sum, const int *bias); | const int *input_sum, const int *bias); | ||||
| extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, | extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, | ||||
| const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier, | |||||
| int left_shift, int right_shift, int row, int col, int stride); | |||||
| const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, | |||||
| int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, | |||||
| size_t peroc); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -55,8 +56,8 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i | |||||
| void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | ||||
| int32_t maxi, bool per_channel) { | |||||
| int32_t maxi, size_t per_channel) { | |||||
| return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi, | return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi, | ||||
| output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, stride); | |||||
| output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -273,13 +273,13 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i | |||||
| "mov x11, %[input_sum] \n" | "mov x11, %[input_sum] \n" | ||||
| "mov x15, %[filter_zp_ptr] \n" | "mov x15, %[filter_zp_ptr] \n" | ||||
| "mov x0, #0 \n" // row 4 count | |||||
| "mov x0, #0 \n" | |||||
| "1: \n" | "1: \n" | ||||
| "cmp x0, %[hw4] \n" | "cmp x0, %[hw4] \n" | ||||
| "beq 11f \n" | "beq 11f \n" | ||||
| "add x0, x0, #4\n" | "add x0, x0, #4\n" | ||||
| "dup v10.4s, wzr \n" | "dup v10.4s, wzr \n" | ||||
| "mov x2, #0 \n" // input deep count | |||||
| "mov x2, #0 \n" | |||||
| "mov x16, x15 \n" | "mov x16, x15 \n" | ||||
| "2: \n" | "2: \n" | ||||
| @@ -313,9 +313,9 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i | |||||
| "b 2b \n" | "b 2b \n" | ||||
| "3: \n" | "3: \n" | ||||
| "mov x12, x11 \n" // tmp inputsm inputsum hw | |||||
| "mov x12, x11 \n" | |||||
| "add x11, x11, #64 \n" | "add x11, x11, #64 \n" | ||||
| "mov x4, #0 \n" // oc count | |||||
| "mov x4, #0 \n" | |||||
| "dup v1.4s, v10.s[0] \n" | "dup v1.4s, v10.s[0] \n" | ||||
| "dup v2.4s, v10.s[1] \n" | "dup v2.4s, v10.s[1] \n" | ||||
| @@ -46,6 +46,18 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { | |||||
| free(filter_zp_ptr_); | free(filter_zp_ptr_); | ||||
| filter_zp_ptr_ = nullptr; | filter_zp_ptr_ = nullptr; | ||||
| } | } | ||||
| if (filter_peroc_ && left_shift_ != nullptr) { | |||||
| free(left_shift_); | |||||
| left_shift_ = nullptr; | |||||
| } | |||||
| if (filter_peroc_ && right_shift_ != nullptr) { | |||||
| free(right_shift_); | |||||
| right_shift_ = nullptr; | |||||
| } | |||||
| if (filter_peroc_ && multiplier_ != nullptr) { | |||||
| free(multiplier_); | |||||
| multiplier_ = nullptr; | |||||
| } | |||||
| FreeResizeBuf(); | FreeResizeBuf(); | ||||
| FreeQuantParam(); | FreeQuantParam(); | ||||
| } | } | ||||
| @@ -59,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { | |||||
| } | } | ||||
| void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { | void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { | ||||
| support_optimize_ = false; | |||||
| support_optimize_ = true; | |||||
| matmul_func_ = MatMulInt8_8x8_r; | matmul_func_ = MatMulInt8_8x8_r; | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; | void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; | ||||
| @@ -78,10 +90,6 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { | |||||
| support_optimize_ = false; | support_optimize_ = false; | ||||
| matmul_func_ = nullptr; | matmul_func_ = nullptr; | ||||
| } | } | ||||
| if (filter_peroc_) { | |||||
| support_optimize_ = false; | |||||
| } | |||||
| #endif | #endif | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -109,6 +117,26 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe | |||||
| for (int fi = 0; fi < output_channel; fi++) { | for (int fi = 0; fi < output_channel; fi++) { | ||||
| filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; | filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; | ||||
| } | } | ||||
| int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM); | |||||
| left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t))); | |||||
| if (left_shift_ == nullptr) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t)); | |||||
| memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t)); | |||||
| right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t))); | |||||
| if (right_shift_ == nullptr) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t)); | |||||
| memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t)); | |||||
| multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t))); | |||||
| if (multiplier_ == nullptr) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t)); | |||||
| memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t)); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -328,9 +356,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||||
| } | } | ||||
| if (filter_peroc_) { | if (filter_peroc_) { | ||||
| cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; | cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; | ||||
| cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM; | |||||
| cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM; | |||||
| cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM; | |||||
| cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM; | |||||
| cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM; | |||||
| cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM; | |||||
| } | } | ||||
| Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, | Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, | ||||
| output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, | output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, | ||||
| @@ -346,9 +374,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||||
| } | } | ||||
| if (filter_peroc_) { | if (filter_peroc_) { | ||||
| cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; | cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; | ||||
| cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM; | |||||
| cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM; | |||||
| cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM; | |||||
| cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM; | |||||
| cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM; | |||||
| cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM; | |||||
| } | } | ||||
| Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, | Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, | ||||
| output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, | output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, | ||||
| @@ -58,8 +58,11 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int InitBiasByzp(void *src_weight, int input_channel, int output_channel); | int InitBiasByzp(void *src_weight, int input_channel, int output_channel); | ||||
| private: | private: | ||||
| int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ | |||||
| int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */ | |||||
| int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */ | |||||
| int32_t *filter_zp_ptr_ = nullptr; /* per-oc */ | |||||
| int32_t *left_shift_ = nullptr; /* per-oc up round */ | |||||
| int32_t *right_shift_ = nullptr; /* per-oc up round */ | |||||
| int32_t *multiplier_ = nullptr; /* per-oc up round */ | |||||
| int8_t *packed_weight_ = nullptr; | int8_t *packed_weight_ = nullptr; | ||||
| int8_t *packed_input_ = nullptr; | int8_t *packed_input_ = nullptr; | ||||
| int8_t *input_ptr_ = nullptr; | int8_t *input_ptr_ = nullptr; | ||||
| @@ -108,8 +108,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) { | |||||
| auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM; | auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM; | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min, | MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min, | ||||
| q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, | |||||
| p->col_ * sizeof(int8_t)); | |||||
| q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res, | |||||
| p->col_ * sizeof(int8_t), 0); | |||||
| #else | #else | ||||
| MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_, | MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_, | ||||
| q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_); | q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_); | ||||
| @@ -101,8 +101,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) { | |||||
| auto &p = quant_params_; | auto &p = quant_params_; | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX, | MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX, | ||||
| p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res, | |||||
| params_->col_ * sizeof(int8_t)); | |||||
| p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res, | |||||
| params_->col_ * sizeof(int8_t), false); | |||||
| #else | #else | ||||
| MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier, | MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier, | ||||
| p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_); | p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_); | ||||
| @@ -120,8 +120,8 @@ TEST_F(TestMatmulInt8, simple) { | |||||
| int multiplier, ls, rs; | int multiplier, ls, rs; | ||||
| QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs); | QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs); | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, | |||||
| rs, ROW, COL, COL); | |||||
| MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, | |||||
| &rs, ROW, COL, COL, false); | |||||
| #else | #else | ||||
| MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL); | MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL); | ||||
| #endif | #endif | ||||