diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S index b7d834da2d..9974e5c771 100644 --- a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S @@ -6,9 +6,9 @@ .type MatmulInt8Neon64, %function #endif -//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, -// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, -// int multiplier, int left_shift, int right_shift, int row, int col, int stride); +//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int row, int col, int stride, int filter_peroc) // x0: a(left matrix ptr) // x1: b(right matrix ptr) @@ -21,31 +21,34 @@ // w8: act_min // w9: act_max // w10: out_zp -// w11: multiplier -// w12: left_shift -// w13: right_shift +// x11: multiplier +// x12: left_shift +// x13: right_shift // w14: row // w15: col // w24: stride +// w27: filter_peroc MatmulInt8Neon64: - sub sp, sp, #192 + sub sp, sp, #208 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 stp x19, x20, [sp], #16 stp x21, x22, [sp], #16 stp x23, x24, [sp], #16 stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 ldr w8, [sp] ldr w9, [sp, #8] ldr w10, [sp, #16] - ldr w11, [sp, #24] - ldr w12, [sp, #32] - ldr w13, [sp, #40] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] ldr w14, [sp, #48] ldr w15, [sp, #56] ldr w24, [sp, #64] + ldr w27, [sp, #72] mov w17, #4 // sizeof(int8)*4 mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 @@ -58,7 +61,7 @@ L1: mov w16, w3 // reset a row4 counter mov w23, w14 // reset a row counter mov x17, x0 // reload a ptr - mov x22, x6 // reload a_sums ptr + mov x22, x6 // reload a_sums ptr L2: cmp w16, #0 beq End2 @@ -167,39 +170,60 @@ End3: addp v19.4s, v28.4s, v30.4s // Add (Bias+Depth*Za*Zb-Za*Bsums) - ld1 {v15.4s}, [x19], #16 + ld1 {v15.4s}, [x19], #16 add v16.4s, v16.4s, v15.4s add v17.4s, v17.4s, v15.4s add v18.4s, v18.4s, v15.4s add v19.4s, v19.4s, v15.4s - // Subtract (Asums*Zb) + cmp w27, #0 + beq PerTLoad +PerCLoad: + ld1 {v20.4s}, [x6], #16 + ld1 {v21.4s}, [x6], #16 + ld1 {v22.4s}, [x6], #16 + ld1 {v23.4s}, [x6], #16 + + ld1 {v13.4s}, [x12] + ld1 {v12.4s}, [x11] + ld1 {v11.4s}, [x13] + b Apply + +PerTLoad: ld1 {v14.4s}, [x22], #16 dup v20.4s, v14.s[0] dup v21.4s, v14.s[1] dup v22.4s, v14.s[2] dup v23.4s, v14.s[3] + + ld1 {v14.s}[0], [x12] + dup v13.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v12.4s, v14.s[0] + ld1 {v14.s}[0], [x13] + dup v11.4s, v14.s[0] + b Apply + +Apply: + // Subtract (Asums*Zb) sub v16.4s, v16.4s, v20.4s sub v17.4s, v17.4s, v21.4s sub v18.4s, v18.4s, v22.4s sub v19.4s, v19.4s, v23.4s // Apply left shift - dup v13.4s, w12 sqshl v16.4s, v16.4s, v13.4s sqshl v17.4s, v17.4s, v13.4s sqshl v18.4s, v18.4s, v13.4s sqshl v19.4s, v19.4s, v13.4s // Apply the fixed-point part of the multiplier. - dup v12.4s, w11 sqrdmulh v16.4s, v16.4s, v12.4s sqrdmulh v17.4s, v17.4s, v12.4s sqrdmulh v18.4s, v18.4s, v12.4s sqrdmulh v19.4s, v19.4s, v12.4s // Apply right shift - dup v11.4s, w13 and v20.16b, v11.16b, v16.16b sshr v20.4s, v20.4s, #31 sqadd v16.4s, v16.4s, v20.4s @@ -268,7 +292,7 @@ Write: beq WriteCol2 cmp w15, #1 beq WriteCol1 - + WriteCol4: st1 {v15.s}[0], [x2], x24 cmp w23, #1 @@ -349,7 +373,7 @@ WriteCol1: st1 {v15.b}[12], [x2], x24 b Endwrite -Endwrite: +Endwrite: sub w16, w16, #4 // a row4 counter - 4 sub w23, w23, #4 // a row counter - 4 b L2 @@ -361,15 +385,23 @@ End2: add x7, x7, #16 // bias ptr + stride add x25, x25, #4 // output + stride(4 * sizeof(int8)) mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #16 + add x11, x11, #16 + add x13, x13, #16 +PerTEnd2: b L1 End1: - sub sp, sp, #192 + sub sp, sp, #208 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 ldp x19, x20, [sp], #16 ldp x21, x22, [sp], #16 ldp x23, x24, [sp], #16 ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 ret #endif diff --git a/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8.S b/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8.S index 7c4bc49bc9..077131ba99 100644 --- a/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8.S +++ b/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8.S @@ -8,7 +8,7 @@ //void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, // const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, -// int multiplier, int left_shift, int right_shift, int row, int col, int stride); +// int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, int peroc); // x0: a(left matrix ptr) // x1: b(right matrix ptr) @@ -21,31 +21,34 @@ // w8: act_min // w9: act_max // w10: out_zp -// w11: multiplier -// w12: left_shift -// w13: right_shift +// x11: multiplier +// x12: left_shift +// x13: right_shift // w14: row // w15: col // w24: stride +// w27: filter_peroc MatmulInt8DpNeon64: - sub sp, sp, #192 + sub sp, sp, #208 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 stp x19, x20, [sp], #16 stp x21, x22, [sp], #16 stp x23, x24, [sp], #16 stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 ldr w8, [sp] ldr w9, [sp, #8] ldr w10, [sp, #16] - ldr w11, [sp, #24] - ldr w12, [sp, #32] - ldr w13, [sp, #40] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] ldr w14, [sp, #48] ldr w15, [sp, #56] ldr w24, [sp, #64] + ldr w27, [sp, #72] mov w17, #8 // sizeof(int8)*8 mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 @@ -226,138 +229,171 @@ End3: add v29.4s, v29.4s, v14.4s add v31.4s, v31.4s, v14.4s + cmp w27, #0 + beq PerTSumLoad +PerCSumLoad: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + b ApplySum +PerTSumLoad: + ld1 {v14.4s}, [x22], #16 + ld1 {v15.4s}, [x22], #16 + dup v0.4s, v14.s[0] + dup v1.4s, v14.s[0] + dup v2.4s, v14.s[1] + dup v3.4s, v14.s[1] + dup v4.4s, v14.s[2] + dup v5.4s, v14.s[2] + dup v6.4s, v14.s[3] + dup v7.4s, v14.s[3] + dup v8.4s, v15.s[0] + dup v9.4s, v15.s[0] + dup v10.4s, v15.s[1] + dup v11.4s, v15.s[1] + dup v12.4s, v15.s[2] + dup v13.4s, v15.s[2] + dup v14.4s, v15.s[3] + dup v15.4s, v14.s[0] +ApplySum: // Subtract (Asums*Zb) - ld1 {v13.4s}, [x22], #16 - ld1 {v12.4s}, [x22], #16 - dup v0.4s, v13.s[0] - dup v1.4s, v13.s[1] - dup v2.4s, v13.s[2] - dup v3.4s, v13.s[3] - dup v4.4s, v12.s[0] - dup v5.4s, v12.s[1] - dup v6.4s, v12.s[2] - dup v7.4s, v12.s[3] sub v16.4s, v16.4s, v0.4s - sub v17.4s, v17.4s, v0.4s - sub v18.4s, v18.4s, v1.4s - sub v19.4s, v19.4s, v1.4s - sub v20.4s, v20.4s, v2.4s - sub v21.4s, v21.4s, v2.4s - sub v22.4s, v22.4s, v3.4s - sub v23.4s, v23.4s, v3.4s - sub v24.4s, v24.4s, v4.4s - sub v25.4s, v25.4s, v4.4s - sub v26.4s, v26.4s, v5.4s - sub v27.4s, v27.4s, v5.4s - sub v28.4s, v28.4s, v6.4s - sub v29.4s, v29.4s, v6.4s - sub v30.4s, v30.4s, v7.4s - sub v31.4s, v31.4s, v7.4s - + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v9.4s + sub v26.4s, v26.4s, v10.4s + sub v27.4s, v27.4s, v11.4s + sub v28.4s, v28.4s, v12.4s + sub v29.4s, v29.4s, v13.4s + sub v30.4s, v30.4s, v14.4s + sub v31.4s, v31.4s, v15.4s + + cmp w27, #0 + beq PerTRoundLoad +PerCRoundLoad: + ld1 {v8.4s, v9.4s}, [x12] + ld1 {v10.4s, v11.4s}, [x11] + ld1 {v12.4s, v13.4s}, [x13] + b ApplyRound +PerTRoundLoad: + ld1 {v14.s}[0], [x12] + dup v8.4s, v14.s[0] + dup v9.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v10.4s, v14.s[0] + dup v11.4s, v14.s[0] + ld1 {v14.s}[0], [x13] + dup v12.4s, v14.s[0] + dup v13.4s, v14.s[0] +ApplyRound: // Apply left shift - dup v11.4s, w12 - sqshl v16.4s, v16.4s, v11.4s - sqshl v17.4s, v17.4s, v11.4s - sqshl v18.4s, v18.4s, v11.4s - sqshl v19.4s, v19.4s, v11.4s - sqshl v20.4s, v20.4s, v11.4s - sqshl v21.4s, v21.4s, v11.4s - sqshl v22.4s, v22.4s, v11.4s - sqshl v23.4s, v23.4s, v11.4s - sqshl v24.4s, v24.4s, v11.4s - sqshl v25.4s, v25.4s, v11.4s - sqshl v26.4s, v26.4s, v11.4s - sqshl v27.4s, v27.4s, v11.4s - sqshl v28.4s, v28.4s, v11.4s - sqshl v29.4s, v29.4s, v11.4s - sqshl v30.4s, v30.4s, v11.4s - sqshl v31.4s, v31.4s, v11.4s + sqshl v16.4s, v16.4s, v8.4s + sqshl v17.4s, v17.4s, v9.4s + sqshl v18.4s, v18.4s, v8.4s + sqshl v19.4s, v19.4s, v9.4s + sqshl v20.4s, v20.4s, v8.4s + sqshl v21.4s, v21.4s, v9.4s + sqshl v22.4s, v22.4s, v8.4s + sqshl v23.4s, v23.4s, v9.4s + sqshl v24.4s, v24.4s, v8.4s + sqshl v25.4s, v25.4s, v9.4s + sqshl v26.4s, v26.4s, v8.4s + sqshl v27.4s, v27.4s, v9.4s + sqshl v28.4s, v28.4s, v8.4s + sqshl v29.4s, v29.4s, v9.4s + sqshl v30.4s, v30.4s, v8.4s + sqshl v31.4s, v31.4s, v9.4s // Apply the fixed-point part of the multiplier. - dup v10.4s, w11 sqrdmulh v16.4s, v16.4s, v10.4s - sqrdmulh v17.4s, v17.4s, v10.4s + sqrdmulh v17.4s, v17.4s, v11.4s sqrdmulh v18.4s, v18.4s, v10.4s - sqrdmulh v19.4s, v19.4s, v10.4s + sqrdmulh v19.4s, v19.4s, v11.4s sqrdmulh v20.4s, v20.4s, v10.4s - sqrdmulh v21.4s, v21.4s, v10.4s + sqrdmulh v21.4s, v21.4s, v11.4s sqrdmulh v22.4s, v22.4s, v10.4s - sqrdmulh v23.4s, v23.4s, v10.4s + sqrdmulh v23.4s, v23.4s, v11.4s sqrdmulh v24.4s, v24.4s, v10.4s - sqrdmulh v25.4s, v25.4s, v10.4s + sqrdmulh v25.4s, v25.4s, v11.4s sqrdmulh v26.4s, v26.4s, v10.4s - sqrdmulh v27.4s, v27.4s, v10.4s + sqrdmulh v27.4s, v27.4s, v11.4s sqrdmulh v28.4s, v28.4s, v10.4s - sqrdmulh v29.4s, v29.4s, v10.4s + sqrdmulh v29.4s, v29.4s, v11.4s sqrdmulh v30.4s, v30.4s, v10.4s - sqrdmulh v31.4s, v31.4s, v10.4s + sqrdmulh v31.4s, v31.4s, v11.4s // Apply right shift - dup v9.4s, w13 - and v0.16b, v9.16b, v16.16b + and v0.16b, v12.16b, v16.16b sshr v0.4s, v0.4s, #31 sqadd v16.4s, v16.4s, v0.4s - srshl v16.4s, v16.4s, v9.4s - and v1.16b, v9.16b, v17.16b + srshl v16.4s, v16.4s, v12.4s + and v1.16b, v13.16b, v17.16b sshr v1.4s, v1.4s, #31 sqadd v17.4s, v17.4s, v1.4s - srshl v17.4s, v17.4s, v9.4s - and v2.16b, v9.16b, v18.16b + srshl v17.4s, v17.4s, v13.4s + and v2.16b, v12.16b, v18.16b sshr v2.4s, v2.4s, #31 sqadd v18.4s, v18.4s, v2.4s - srshl v18.4s, v18.4s, v9.4s - and v3.16b, v9.16b, v19.16b + srshl v18.4s, v18.4s, v12.4s + and v3.16b, v13.16b, v19.16b sshr v3.4s, v3.4s, #31 sqadd v19.4s, v19.4s, v3.4s - srshl v19.4s, v19.4s, v9.4s - and v0.16b, v9.16b, v20.16b + srshl v19.4s, v19.4s, v13.4s + and v0.16b, v12.16b, v20.16b sshr v0.4s, v0.4s, #31 sqadd v20.4s, v20.4s, v0.4s - srshl v20.4s, v20.4s, v9.4s - and v1.16b, v9.16b, v21.16b + srshl v20.4s, v20.4s, v12.4s + and v1.16b, v13.16b, v21.16b sshr v1.4s, v1.4s, #31 sqadd v21.4s, v21.4s, v1.4s - srshl v21.4s, v21.4s, v9.4s - and v2.16b, v9.16b, v22.16b + srshl v21.4s, v21.4s, v13.4s + and v2.16b, v12.16b, v22.16b sshr v2.4s, v2.4s, #31 sqadd v22.4s, v22.4s, v2.4s - srshl v22.4s, v22.4s, v9.4s - and v3.16b, v9.16b, v23.16b + srshl v22.4s, v22.4s, v12.4s + and v3.16b, v13.16b, v23.16b sshr v3.4s, v3.4s, #31 sqadd v23.4s, v23.4s, v3.4s - srshl v23.4s, v23.4s, v9.4s - and v0.16b, v9.16b, v24.16b + srshl v23.4s, v23.4s, v13.4s + and v0.16b, v12.16b, v24.16b sshr v0.4s, v0.4s, #31 sqadd v24.4s, v24.4s, v0.4s - srshl v24.4s, v24.4s, v9.4s - and v1.16b, v9.16b, v25.16b + srshl v24.4s, v24.4s, v12.4s + and v1.16b, v13.16b, v25.16b sshr v1.4s, v1.4s, #31 sqadd v25.4s, v25.4s, v1.4s - srshl v25.4s, v25.4s, v9.4s - and v2.16b, v9.16b, v26.16b + srshl v25.4s, v25.4s, v13.4s + and v2.16b, v12.16b, v26.16b sshr v2.4s, v2.4s, #31 sqadd v26.4s, v26.4s, v2.4s - srshl v26.4s, v26.4s, v9.4s - and v3.16b, v9.16b, v27.16b + srshl v26.4s, v26.4s, v12.4s + and v3.16b, v13.16b, v27.16b sshr v3.4s, v3.4s, #31 sqadd v27.4s, v27.4s, v3.4s - srshl v27.4s, v27.4s, v9.4s - and v0.16b, v9.16b, v28.16b + srshl v27.4s, v27.4s, v13.4s + and v0.16b, v12.16b, v28.16b sshr v0.4s, v0.4s, #31 sqadd v28.4s, v28.4s, v0.4s - srshl v28.4s, v28.4s, v9.4s - and v1.16b, v9.16b, v29.16b + srshl v28.4s, v28.4s, v12.4s + and v1.16b, v13.16b, v29.16b sshr v1.4s, v1.4s, #31 sqadd v29.4s, v29.4s, v1.4s - srshl v29.4s, v29.4s, v9.4s - and v2.16b, v9.16b, v30.16b + srshl v29.4s, v29.4s, v13.4s + and v2.16b, v12.16b, v30.16b sshr v2.4s, v2.4s, #31 sqadd v30.4s, v30.4s, v2.4s - srshl v30.4s, v30.4s, v9.4s - and v3.16b, v9.16b, v31.16b + srshl v30.4s, v30.4s, v12.4s + and v3.16b, v13.16b, v31.16b sshr v3.4s, v3.4s, #31 sqadd v31.4s, v31.4s, v3.4s - srshl v31.4s, v31.4s, v9.4s + srshl v31.4s, v31.4s, v13.4s // Add the destination zero point dup v8.4s, w10 @@ -793,15 +829,23 @@ End2: add x7, x7, #32 // bias ptr + stride add x25, x25, #8 // output + stride(8 * sizeof(int8)) mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #32 + add x11, x11, #32 + add x13, x13, #32 +PerTEnd2: b L1 End1: - sub sp, sp, #192 + sub sp, sp, #208 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 ldp x19, x20, [sp], #16 ldp x21, x22, [sp], #16 ldp x23, x24, [sp], #16 ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 ret #endif diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index 5dccd3b772..f1e87c8682 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -385,12 +385,302 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t * int8_t *pack_r = packed_input; int32_t *input_sum_r = input_sum; - /* per layer */ for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { const int8_t *src_ic = src_r; int8_t *pack_ic = pack_r; int32_t *input_sum_oc = input_sum_r; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "dup v0.4s, v16.s[0] \n" + "dup v1.4s, v16.s[1] \n" + "dup v2.4s, v16.s[2] \n" + "dup v3.4s, v16.s[3] \n" + "dup v4.4s, v17.s[0] \n" + "dup v5.4s, v17.s[1] \n" + "dup v6.4s, v17.s[2] \n" + "dup v7.4s, v17.s[3] \n" + "mov x4, #0 \n" + "mov x10, %[filter_zp] \n" + "mov x11, %[input_sum_oc] \n" + + "7: \n" + "cmp x4, %[oc_8div] \n" + "beq 8f \n" + "add x4, x4, #8\n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.4s}, [x10], #16\n" + + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "add x11, x11, %[input_sum_stride] \n" + "b 7b \n" + + "8: \n" + "cmp %[oc_8res], #0\n" + "beq 17f \n" + + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "cmp %[oc_8res], #1\n" + "beq 9f \n" + "cmp %[oc_8res], #2\n" + "beq 10f \n" + "cmp %[oc_8res], #3\n" + "beq 11f \n" + "cmp %[oc_8res], #4\n" + "beq 12f \n" + "cmp %[oc_8res], #5\n" + "beq 13f \n" + "cmp %[oc_8res], #6\n" + "beq 14f \n" + "cmp %[oc_8res], #7\n" + "beq 15f \n" + + "9: \n" + "ld1 {v16.s}[0], [x10] \n" + "b 16f \n" + + "10: \n" + "ld1 {v16.h}[0], [x10] \n" + "b 16f \n" + + "11: \n" + "ld1 {v16.h}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v16.s}[2], [x10] \n" + "b 16f \n" + + "12: \n" + "ld1 {v16.4s}, [x10] \n" + "b 16f \n" + + "13: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.s}[0], [x10] \n" + "b 16f \n" + + "14: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.h}[0], [x10] \n" + "b 16f \n" + + "15: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.h}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v17.s}[2], [x10] \n" + "b 16f \n" + + "16: \n" + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "17: \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp), + [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride), + [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res) + : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); +#else int32_t tmp_sum_value[8] = {0}; for (int ici = 0; ici < ic_4div; ici += C4NUM) { for (int i = 0; i < C8NUM; i++) { @@ -440,7 +730,7 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t * } } } /* oc8 res done */ - +#endif src_r += input_channel * C8NUM; pack_r += ic4 * C8NUM; input_sum_r += C8NUM * C8NUM; @@ -520,9 +810,9 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i size_t src_stride = input_channel; size_t ic_4res = input_channel - ic_4div; asm volatile( - "dup v10.4s, wzr \n" - "dup v11.4s, wzr \n" - "mov x20, %[input_sum_r] \n" + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "mov x14, %[input_sum_r] \n" "dup v20.4s, %w[filter_zp] \n" "mov x10, %[src_ic] \n" @@ -563,8 +853,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i "saddlp v0.4s, v4.8h \n" "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" "b 1b \n" "3: \n" /* col res 1 */ @@ -586,8 +876,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i "saddlp v5.8h, v1.16b \n" "saddlp v0.4s, v4.8h \n" "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" "b 6f \n" "4: \n" /* col res 2 */ @@ -609,8 +899,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i "saddlp v5.8h, v1.16b \n" "saddlp v0.4s, v4.8h \n" "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" "b 6f \n" "5: \n" /* col res 3 */ @@ -641,21 +931,21 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i "saddlp v5.8h, v1.16b \n" "saddlp v0.4s, v4.8h \n" "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" "b 6f \n" "6: \n" - "mul v10.4s, v10.4s, v20.4s \n" - "mul v11.4s, v11.4s, v20.4s \n" + "mul v16.4s, v16.4s, v20.4s \n" + "mul v17.4s, v17.4s, v20.4s \n" - "st1 {v10.4s}, [x20], #16 \n" - "st1 {v11.4s}, [x20], #16 \n" + "st1 {v16.4s}, [x14], #16 \n" + "st1 {v17.4s}, [x14], #16 \n" : : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) - : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", + : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v20"); #else int32_t tmp_sum_value[8] = {0}; @@ -728,10 +1018,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.filter_arg_num_ != 1); + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc); return; } @@ -756,24 +1046,17 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param) { - if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) { - return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, - bias, left_shift, right_shift, multiplier, - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.filter_arg_num_ != 1); - } + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; #ifdef ENABLE_ARM64 MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_); + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col, + conv_param->output_channel_, is_per_oc); #else MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.filter_arg_num_ != 1); + is_per_oc); #endif return; } diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 488bc9693e..06545b77b5 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -269,7 +269,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, - bool per_channel) { + size_t per_channel) { /* row8x4-major * row4x8-major => (int8)row-major */ for (int r = 0; r < row; r++) { for (int c = 0; c < col; c++) { diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 5babb3b443..0092bdcb7e 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -39,7 +39,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, - bool per_channel); + size_t per_channel); void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); @@ -59,8 +59,8 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, #ifdef ENABLE_ARM64 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, - const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift, - int right_shift, int row, int col, int stride); + const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, + int32_t *right_shift, int row, int col, int stride, int filter_peroc); void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, const int *input_sum, const int *bias); diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h index 8f6b562974..0f89aa6317 100644 --- a/mindspore/lite/nnacl/matmul_parameter.h +++ b/mindspore/lite/nnacl/matmul_parameter.h @@ -25,7 +25,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, i typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, - int32_t maxi, bool per_channel); + int32_t maxi, size_t per_channel); typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col); diff --git a/mindspore/lite/nnacl/opt_op_handler.c b/mindspore/lite/nnacl/opt_op_handler.c index 294f6af837..7a16c16990 100644 --- a/mindspore/lite/nnacl/opt_op_handler.c +++ b/mindspore/lite/nnacl/opt_op_handler.c @@ -30,8 +30,9 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_ extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, const int *input_sum, const int *bias); extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, - const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier, - int left_shift, int right_shift, int row, int col, int stride); + const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, + int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, + size_t peroc); #ifdef __cplusplus } @@ -55,8 +56,8 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, - int32_t maxi, bool per_channel) { + int32_t maxi, size_t per_channel) { return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi, - output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, stride); + output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel); } #endif diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index ce4b7b2c59..e142a176f5 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -273,13 +273,13 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i "mov x11, %[input_sum] \n" "mov x15, %[filter_zp_ptr] \n" - "mov x0, #0 \n" // row 4 count + "mov x0, #0 \n" "1: \n" "cmp x0, %[hw4] \n" "beq 11f \n" "add x0, x0, #4\n" "dup v10.4s, wzr \n" - "mov x2, #0 \n" // input deep count + "mov x2, #0 \n" "mov x16, x15 \n" "2: \n" @@ -313,9 +313,9 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i "b 2b \n" "3: \n" - "mov x12, x11 \n" // tmp inputsm inputsum hw + "mov x12, x11 \n" "add x11, x11, #64 \n" - "mov x4, #0 \n" // oc count + "mov x4, #0 \n" "dup v1.4s, v10.s[0] \n" "dup v2.4s, v10.s[1] \n" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc index 80147822b2..2781e2e6f4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc @@ -46,6 +46,18 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { free(filter_zp_ptr_); filter_zp_ptr_ = nullptr; } + if (filter_peroc_ && left_shift_ != nullptr) { + free(left_shift_); + left_shift_ = nullptr; + } + if (filter_peroc_ && right_shift_ != nullptr) { + free(right_shift_); + right_shift_ = nullptr; + } + if (filter_peroc_ && multiplier_ != nullptr) { + free(multiplier_); + multiplier_ = nullptr; + } FreeResizeBuf(); FreeQuantParam(); } @@ -59,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { } void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { - support_optimize_ = false; + support_optimize_ = true; matmul_func_ = MatMulInt8_8x8_r; #ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; @@ -78,10 +90,6 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { support_optimize_ = false; matmul_func_ = nullptr; } - - if (filter_peroc_) { - support_optimize_ = false; - } #endif return; } @@ -109,6 +117,26 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe for (int fi = 0; fi < output_channel; fi++) { filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; } + + int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM); + left_shift_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + if (left_shift_ == nullptr) { + return RET_ERROR; + } + memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t)); + memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t)); + right_shift_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + if (right_shift_ == nullptr) { + return RET_ERROR; + } + memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t)); + memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t)); + multiplier_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + if (multiplier_ == nullptr) { + return RET_ERROR; + } + memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t)); + memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t)); } return RET_OK; } @@ -328,9 +356,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { } if (filter_peroc_) { cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; - cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM; - cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM; - cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM; + cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM; + cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM; + cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM; } Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, @@ -346,9 +374,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { } if (filter_peroc_) { cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; - cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM; - cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM; - cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM; + cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM; + cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM; + cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM; } Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index 42ca9b972c..96f6a11f09 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -58,8 +58,11 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int InitBiasByzp(void *src_weight, int input_channel, int output_channel); private: - int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ - int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */ + int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */ + int32_t *filter_zp_ptr_ = nullptr; /* per-oc */ + int32_t *left_shift_ = nullptr; /* per-oc up round */ + int32_t *right_shift_ = nullptr; /* per-oc up round */ + int32_t *multiplier_ = nullptr; /* per-oc up round */ int8_t *packed_weight_ = nullptr; int8_t *packed_input_ = nullptr; int8_t *input_ptr_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc index 80a2d5f710..cac0c1fb8c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -108,8 +108,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) { auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM; #ifdef ENABLE_ARM64 MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min, - q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, - p->col_ * sizeof(int8_t)); + q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res, + p->col_ * sizeof(int8_t), 0); #else MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc index 09059d7624..67f403c217 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc @@ -101,8 +101,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) { auto &p = quant_params_; #ifdef ENABLE_ARM64 MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX, - p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res, - params_->col_ * sizeof(int8_t)); + p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res, + params_->col_ * sizeof(int8_t), false); #else MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index 9c53eaaea5..e9abc6e291 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -120,8 +120,8 @@ TEST_F(TestMatmulInt8, simple) { int multiplier, ls, rs; QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs); #ifdef ENABLE_ARM64 - MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, - rs, ROW, COL, COL); + MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, + &rs, ROW, COL, COL, false); #else MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL); #endif