diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S
index 5ac9a0f365..37711436d9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S
@@ -8,8 +8,8 @@
 #endif
 
 // void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-// size_t shift_before, size_t shift_after);
+// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
+// int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 IndirectGemmInt8_4x4:
 
@@ -36,18 +36,26 @@ IndirectGemmInt8_4x4:
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should also be preserved
 // whereas our coding style does not permit such a number of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16
     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23
 
     mul x5, x4, x5
     mov x4, #1
@@ -189,12 +197,6 @@ IndirectGemmInt8_4x4:
     sadalp v30.4s, v14.8h
     sadalp v31.4s, v15.8h
 
-    // load sum
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20]
     // pairwise add
     addp v16.4s, v16.4s, v17.4s
     addp v18.4s, v18.4s, v19.4s
@@ -212,28 +214,51 @@ IndirectGemmInt8_4x4:
     addp v20.4s, v20.4s, v22.4s
     addp v24.4s, v24.4s, v26.4s
     addp v28.4s, v28.4s, v30.4s
+    cbz x20, NoSum
+    // load sum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1r {v8.4s}, [x22], x23
+    ld1r {v9.4s}, [x22], x23
+    ld1r {v10.4s}, [x22], x23
+    ld1r {v11.4s}, [x22]
+    b AddSum
+    SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22]
+    AddSum:
     sub v16.4s, v16.4s, v8.4s
     sub v20.4s, v20.4s, v9.4s
     sub v24.4s, v24.4s, v10.4s
     sub v28.4s, v28.4s, v11.4s
+    NoSum:
     add v16.4s, v16.4s, v12.4s
     add v20.4s, v20.4s, v12.4s
     add v24.4s, v24.4s, v12.4s
     add v28.4s, v28.4s, v12.4s
 
-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+    PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+    QuantizeStart:
     sqshl v16.4s, v16.4s, v2.4s
     sqshl v20.4s, v20.4s, v2.4s
    sqshl v24.4s, v24.4s, v2.4s
     sqshl v28.4s, v28.4s, v2.4s
-    dup v3.4s, w17
     sqrdmulh v16.4s, v16.4s, v3.4s
     sqrdmulh v20.4s, v20.4s, v3.4s
     sqrdmulh v24.4s, v24.4s, v3.4s
     sqrdmulh v28.4s, v28.4s, v3.4s
-    dup v4.4s, w19
     and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
@@ -325,15 +350,25 @@ IndirectGemmInt8_4x4:
         bne LoopKsize
 
     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+    NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+    NoChannelForward:
     cbz x3, NoStepForward
     add x3, x3, #16
     NoStepForward:
     bgt LoopOc
 
-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif
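Both kernels keep the same quantization tail: a saturating left shift (sqshl by v2), a saturating rounding doubling high multiply (sqrdmulh by v3), and a rounding right shift (srshl by v4) whose and/sshr/sqadd prelude nudges negative accumulators so that ties round away from zero. A scalar sketch of what each lane computes; helper names are illustrative, saturation of the left shift is omitted, and shift_after is assumed to be stored negated, as srshl expects:

#include <stdint.h>

static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  // sqrdmulh: high 32 bits of 2*a*b with rounding, saturated.
  int64_t rounded = ((int64_t)a * b + (1LL << 30)) >> 31;
  if (rounded > INT32_MAX) return INT32_MAX;
  if (rounded < INT32_MIN) return INT32_MIN;
  return (int32_t)rounded;
}

static int32_t RoundingRightShift(int32_t x, int n) {
  // and/sshr/sqadd subtract 1 from negative accumulators; srshl by -n then
  // rounds half up, so the pair rounds to nearest with ties away from zero.
  if (n <= 0) return x;
  int64_t fixed = (int64_t)x - (x < 0 ? 1 : 0);
  return (int32_t)((fixed + (1LL << (n - 1))) >> n);
}

static int8_t Requantize(int32_t acc, int left_shift, int32_t multiplier, int right_shift,
                         int32_t out_zp, int32_t act_min, int32_t act_max) {
  acc = (int32_t)((int64_t)acc << left_shift);               // sqshl by v2
  acc = SaturatingRoundingDoublingHighMul(acc, multiplier);  // sqrdmulh by v3
  acc = RoundingRightShift(acc, right_shift);                // srshl by v4
  acc += out_zp;                                             // add output zero point
  if (acc < act_min) acc = act_min;                          // clamp to activation range
  if (acc > act_max) acc = act_max;
  return (int8_t)acc;
}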
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S
index 7e57c022a3..ae4c07fbeb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S
@@ -8,8 +8,8 @@
 #endif
 
 // void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-// size_t shift_before, size_t shift_after);
+// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
+// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 // we use the sdot instruction on cores that support dotprod (Armv8.2-A w/dp or later)
 // the mrs instruction can read system register ID_AA64ISAR0_EL1 (or s3_0_c0_c6_0 on Armv8.2-A)
@@ -17,35 +17,64 @@
 IndirectGemmInt8_24x4_dp:
 
 .macro INIT_BIAS
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20], #4
-    ld1r {v12.4s}, [x20], #4
-    ld1r {v13.4s}, [x20], #4
-    ld1r {v14.4s}, [x20], #4
-    ld1r {v15.4s}, [x20], #4
-    ld1r {v16.4s}, [x20], #4
-    ld1r {v17.4s}, [x20], #4
-    ld1r {v18.4s}, [x20], #4
-    ld1r {v19.4s}, [x20], #4
-    ld1r {v20.4s}, [x20], #4
-    ld1r {v21.4s}, [x20], #4
-    ld1r {v22.4s}, [x20], #4
-    ld1r {v23.4s}, [x20], #4
-    ld1r {v24.4s}, [x20], #4
-    ld1r {v25.4s}, [x20], #4
-    ld1r {v26.4s}, [x20], #4
-    ld1r {v27.4s}, [x20], #4
-    ld1r {v28.4s}, [x20], #4
-    ld1r {v29.4s}, [x20], #4
-    ld1r {v30.4s}, [x20], #4
-    ld1r {v31.4s}, [x20], #4
     dup v7.4s, wzr
     cbz x3, InitBias
     ld1 {v7.4s}, [x3]
     InitBias:
+    cbz x20, NoSum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1 {v8.4s}, [x22], x23
+    ld1 {v9.4s}, [x22], x23
+    ld1 {v10.4s}, [x22], x23
+    ld1 {v11.4s}, [x22], x23
+    ld1 {v12.4s}, [x22], x23
+    ld1 {v13.4s}, [x22], x23
+    ld1 {v14.4s}, [x22], x23
+    ld1 {v15.4s}, [x22], x23
+    ld1 {v16.4s}, [x22], x23
+    ld1 {v17.4s}, [x22], x23
+    ld1 {v18.4s}, [x22], x23
+    ld1 {v19.4s}, [x22], x23
+    ld1 {v20.4s}, [x22], x23
+    ld1 {v21.4s}, [x22], x23
+    ld1 {v22.4s}, [x22], x23
+    ld1 {v23.4s}, [x22], x23
+    ld1 {v24.4s}, [x22], x23
+    ld1 {v25.4s}, [x22], x23
+    ld1 {v26.4s}, [x22], x23
+    ld1 {v27.4s}, [x22], x23
+    ld1 {v28.4s}, [x22], x23
+    ld1 {v29.4s}, [x22], x23
+    ld1 {v30.4s}, [x22], x23
+    ld1 {v31.4s}, [x22], x23
+    b AddSum
+    SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22], #4
+    ld1r {v12.4s}, [x22], #4
+    ld1r {v13.4s}, [x22], #4
+    ld1r {v14.4s}, [x22], #4
+    ld1r {v15.4s}, [x22], #4
+    ld1r {v16.4s}, [x22], #4
+    ld1r {v17.4s}, [x22], #4
+    ld1r {v18.4s}, [x22], #4
+    ld1r {v19.4s}, [x22], #4
+    ld1r {v20.4s}, [x22], #4
+    ld1r {v21.4s}, [x22], #4
+    ld1r {v22.4s}, [x22], #4
+    ld1r {v23.4s}, [x22], #4
+    ld1r {v24.4s}, [x22], #4
+    ld1r {v25.4s}, [x22], #4
+    ld1r {v26.4s}, [x22], #4
+    ld1r {v27.4s}, [x22], #4
+    ld1r {v28.4s}, [x22], #4
+    ld1r {v29.4s}, [x22], #4
+    ld1r {v30.4s}, [x22], #4
+    ld1r {v31.4s}, [x22], #4
+    AddSum:
     sub v8.4s, v7.4s, v8.4s
     sub v9.4s, v7.4s, v9.4s
     sub v10.4s, v7.4s, v10.4s
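INIT_BIAS above now chooses the accumulators' starting value instead of unconditionally loading input sums: with asymmetric (x20) clear, each row starts at the bias alone; with it set, a precomputed input-sum correction is subtracted, loaded either once per row (ld1r in the SymSum branch) or as four per-channel values stepped by the oc-derived stride in x23. In scalar form; the layout and names are assumptions for illustration, not taken from the kernel:

#include <stdint.h>
#include <stddef.h>

// Starting value of one accumulator, as INIT_BIAS computes it: for an
// asymmetrically quantized filter the int8 GEMM must subtract
// filter_zp * sum_k(input[row][k]); input_sum is assumed to hold that
// correction, either per row or per (row, output channel).
static int32_t InitAcc(int32_t bias_c, const int32_t *input_sum, size_t row,
                       size_t oc, size_t sum_stride, size_t asymmetric,
                       size_t per_channel) {
  if (!asymmetric) return bias_c;                      // NoSum: bias only
  if (per_channel) {
    return bias_c - input_sum[row * sum_stride + oc];  // ld1, stepped by x23
  }
  return bias_c - input_sum[row];                      // SymSum: one sum per row
}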
@@ -70,24 +99,59 @@
     sub v29.4s, v7.4s, v29.4s
     sub v30.4s, v7.4s, v30.4s
     sub v31.4s, v7.4s, v31.4s
+    b InitBiasEnd
+    NoSum:
+    mov v8.16b, v7.16b
+    mov v9.16b, v7.16b
+    mov v10.16b, v7.16b
+    mov v11.16b, v7.16b
+    mov v12.16b, v7.16b
+    mov v13.16b, v7.16b
+    mov v14.16b, v7.16b
+    mov v15.16b, v7.16b
+    mov v16.16b, v7.16b
+    mov v17.16b, v7.16b
+    mov v18.16b, v7.16b
+    mov v19.16b, v7.16b
+    mov v20.16b, v7.16b
+    mov v21.16b, v7.16b
+    mov v22.16b, v7.16b
+    mov v23.16b, v7.16b
+    mov v24.16b, v7.16b
+    mov v25.16b, v7.16b
+    mov v26.16b, v7.16b
+    mov v27.16b, v7.16b
+    mov v28.16b, v7.16b
+    mov v29.16b, v7.16b
+    mov v30.16b, v7.16b
+    mov v31.16b, v7.16b
+    InitBiasEnd:
 .endm
 
 // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should also be preserved
 // whereas our coding style does not permit such a number of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16
     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23
 
     mul x5, x4, x5
     mov x4, #1
@@ -206,7 +270,7 @@
     b LoopIc
     LoopIcEnd:
-    mov x20, x15
+    mov x22, x15
     // load input for output 1-8
     ld1 {v0.16b, v1.16b}, [x12], #32
     .inst 0x4f82e0d0  // sdot v16.4s, v6.16b, v2.4b[0]
@@ -276,7 +340,16 @@
     .inst 0x4fa5e8df  // sdot v31.4s, v6.16b, v5.4b[3]
 
     Quantization:
-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+    PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+    QuantizeStart:
     sqshl v8.4s, v8.4s, v2.4s
     sqshl v9.4s, v9.4s, v2.4s
     sqshl v10.4s, v10.4s, v2.4s
@@ -302,7 +375,6 @@
     sqshl v30.4s, v30.4s, v2.4s
     sqshl v31.4s, v31.4s, v2.4s
 
-    dup v3.4s, w17
     sqrdmulh v8.4s, v8.4s, v3.4s
     sqrdmulh v9.4s, v9.4s, v3.4s
     sqrdmulh v10.4s, v10.4s, v3.4s
@@ -328,100 +400,99 @@
     sqrdmulh v30.4s, v30.4s, v3.4s
     sqrdmulh v31.4s, v31.4s, v3.4s
 
-    dup v4.4s, w19
-    add v0.16b, v4.16b, v8.16b
+    and v0.16b, v4.16b, v8.16b
     sshr v0.4s, v0.4s, #31
     sqadd v8.4s, v8.4s, v0.4s
     srshl v8.4s, v8.4s, v4.4s
-    add v0.16b, v4.16b, v9.16b
+    and v1.16b, v4.16b, v9.16b
     sshr v1.4s, v1.4s, #31
     sqadd v9.4s, v9.4s, v1.4s
     srshl v9.4s, v9.4s, v4.4s
-    add v2.16b, v4.16b, v10.16b
+    and v2.16b, v4.16b, v10.16b
     sshr v2.4s, v2.4s, #31
     sqadd v10.4s, v10.4s, v2.4s
     srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v11.16b
+    and v3.16b, v4.16b, v11.16b
     sshr v3.4s, v3.4s, #31
     sqadd v11.4s, v11.4s, v3.4s
     srshl v11.4s, v11.4s, v4.4s
-    add v0.16b, v4.16b, v12.16b
+    and v0.16b, v4.16b, v12.16b
     sshr v0.4s, v0.4s, #31
     sqadd v12.4s, v12.4s, v0.4s
     srshl v12.4s, v12.4s, v4.4s
-    add v0.16b, v4.16b, v13.16b
+    and v1.16b, v4.16b, v13.16b
     sshr v1.4s, v1.4s, #31
     sqadd v13.4s, v13.4s, v1.4s
     srshl v13.4s, v13.4s, v4.4s
-    add v2.16b, v4.16b, v14.16b
+    and v2.16b, v4.16b, v14.16b
     sshr v2.4s, v2.4s, #31
     sqadd v14.4s, v14.4s, v2.4s
     srshl v14.4s, v14.4s, v4.4s
-    add v3.16b, v4.16b, v15.16b
+    and v3.16b, v4.16b, v15.16b
     sshr v3.4s, v3.4s, #31
     sqadd v15.4s, v15.4s, v3.4s
     srshl v15.4s, v15.4s, v4.4s
-    add v0.16b, v4.16b, v16.16b
+    and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
     srshl v16.4s, v16.4s, v4.4s
-    add v0.16b, v4.16b, v17.16b
+    and v1.16b, v4.16b, v17.16b
     sshr v1.4s, v1.4s, #31
     sqadd v17.4s, v17.4s, v1.4s
     srshl v17.4s, v17.4s, v4.4s
-    add v2.16b, v4.16b, v18.16b
+    and v2.16b, v4.16b, v18.16b
     sshr v2.4s, v2.4s, #31
     sqadd v18.4s, v18.4s, v2.4s
     srshl v18.4s, v18.4s, v4.4s
-    add v3.16b, v4.16b, v19.16b
+    and v3.16b, v4.16b, v19.16b
     sshr v3.4s, v3.4s, #31
     sqadd v19.4s, v19.4s, v3.4s
     srshl v19.4s, v19.4s, v4.4s
-    add v0.16b, v4.16b, v20.16b
+    and v0.16b, v4.16b, v20.16b
     sshr v0.4s, v0.4s, #31
     sqadd v20.4s, v20.4s, v0.4s
     srshl v20.4s, v20.4s, v4.4s
-    add v0.16b, v4.16b, v21.16b
+    and v1.16b, v4.16b, v21.16b
     sshr v1.4s, v1.4s, #31
     sqadd v21.4s, v21.4s, v1.4s
     srshl v21.4s, v21.4s, v4.4s
-    add v2.16b, v4.16b, v22.16b
+    and v2.16b, v4.16b, v22.16b
     sshr v2.4s, v2.4s, #31
     sqadd v22.4s, v22.4s, v2.4s
-    srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v23.16b
+    srshl v22.4s, v22.4s, v4.4s
+    and v3.16b, v4.16b, v23.16b
     sshr v3.4s, v3.4s, #31
     sqadd v23.4s, v23.4s, v3.4s
     srshl v23.4s, v23.4s, v4.4s
-    add v0.16b, v4.16b, v24.16b
+    and v0.16b, v4.16b, v24.16b
     sshr v0.4s, v0.4s, #31
     sqadd v24.4s, v24.4s, v0.4s
     srshl v24.4s, v24.4s, v4.4s
-    add v0.16b, v4.16b, v25.16b
+    and v1.16b, v4.16b, v25.16b
     sshr v1.4s, v1.4s, #31
     sqadd v25.4s, v25.4s, v1.4s
     srshl v25.4s, v25.4s, v4.4s
-    add v2.16b, v4.16b, v26.16b
+    and v2.16b, v4.16b, v26.16b
     sshr v2.4s, v2.4s, #31
     sqadd v26.4s, v26.4s, v2.4s
     srshl v26.4s, v26.4s, v4.4s
-    add v3.16b, v4.16b, v27.16b
+    and v3.16b, v4.16b, v27.16b
     sshr v3.4s, v3.4s, #31
     sqadd v27.4s, v27.4s, v3.4s
     srshl v27.4s, v27.4s, v4.4s
-    add v0.16b, v4.16b, v28.16b
+    and v0.16b, v4.16b, v28.16b
     sshr v0.4s, v0.4s, #31
     sqadd v28.4s, v28.4s, v0.4s
     srshl v28.4s, v28.4s, v4.4s
-    add v0.16b, v4.16b, v29.16b
+    and v1.16b, v4.16b, v29.16b
     sshr v1.4s, v1.4s, #31
     sqadd v29.4s, v29.4s, v1.4s
     srshl v29.4s, v29.4s, v4.4s
-    add v2.16b, v4.16b, v30.16b
+    and v2.16b, v4.16b, v30.16b
     sshr v2.4s, v2.4s, #31
     sqadd v30.4s, v30.4s, v2.4s
     srshl v30.4s, v30.4s, v4.4s
-    add v3.16b, v4.16b, v31.16b
+    and v3.16b, v4.16b, v31.16b
     sshr v3.4s, v3.4s, #31
     sqadd v31.4s, v31.4s, v3.4s
     srshl v31.4s, v31.4s, v4.4s
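The Quantization entry above is where the pointer-based parameters pay off: per-layer mode (per_channel == 0) broadcasts a single multiplier/shift triple to all four output-channel lanes with ld1r, while per-channel mode loads four consecutive int32 values with ld1. A C sketch of the selection; names are illustrative, and the right shift is kept negative because srshl shifts left by a signed amount:

#include <stdint.h>

// Materialize the three requantization parameters for one oc4 block.
static void LoadQuantParams(const int32_t *multiplier, const int32_t *left_shift,
                            const int32_t *right_shift, int per_channel,
                            int32_t mul[4], int32_t ls[4], int32_t rs[4]) {
  for (int i = 0; i < 4; ++i) {
    int idx = per_channel ? i : 0;  // ld1 (four values) vs ld1r (broadcast)
    mul[i] = multiplier[idx];       // v3, consumed by sqrdmulh
    ls[i] = left_shift[idx];        // v2, consumed by sqshl
    rs[i] = right_shift[idx];       // v4, negative, consumed by srshl
  }
}

In the per-channel case the outer channel loop also has to advance out_multiplier, shift_before and shift_after by 16 bytes (four int32 lanes) per oc4 block, which is exactly what the new NoChannelForward/NoSumForward sequence in the epilogue below does.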
@@ -694,15 +765,24 @@
         bne LoopKsize
 
     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+    NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+    NoChannelForward:
     cbz x3, NoStepForward
     add x3, x3, #16
     NoStepForward:
     bgt LoopOc
 
-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c
index 6500ef8b6e..4c130ba995 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c
@@ -16,7 +16,7 @@
 
 #include "nnacl/fp32/common_func.h"
 
-#ifndef __aarch64__
+#ifndef ENABLE_ARM64
 void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
                size_t row, size_t col) {
   for (int r = 0; r < row; r++) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h
index cb680fc123..a5b8aa2f8b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h
@@ -40,8 +40,8 @@
 void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
                                size_t oc4, size_t offset);
 void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias,
                           size_t ksize, size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);
 void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                         size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c
index f8d19cf6f3..cbee7f19c2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c
@@ -29,14 +29,12 @@
   int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
   int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
   int oc4 = UP_DIV(output_channel, C4NUM);
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
+  size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+  size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
   IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
                        output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-                       shift_before, shift_after);
-// #elif defined(ENABLE_ARM32)
-//   IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
-//                        output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-//                        shift_before, shift_after);
+                       shift_before, shift_after, asymmetric, per_channel);
 #else
   int tile_num = conv_param->tile_num_;
   int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@@ -124,8 +122,10 @@
   int oc4 = UP_DIV(output_channel, C4NUM);
   if (gemm_func != NULL) {
 #ifdef __aarch64__
+    size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+    size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
     gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum,
-              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
+              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 #endif
   } else {
     int tile_num = conv_param->tile_num_;
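On the C side the two new flags are produced by masking the quantization state, so any nonzero value means enabled. When per_channel is set, the kernels read four int32 entries per oc4 block from out_multiplier, shift_before and shift_after and step those pointers by 16 bytes, so the arrays need at least UP_DIV(oc, 4) * 4 entries. A hedged sketch of that padding; the helper is invented for illustration and is not part of this patch:

#include <stdint.h>
#include <stdlib.h>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

// Pad a per-channel parameter array to a multiple of four channels so the
// assembly can always load a full 4-lane vector per oc4 block.
static int32_t *PadQuantParams(const int32_t *src, int oc) {
  int padded = UP_DIV(oc, C4NUM) * C4NUM;
  int32_t *dst = (int32_t *)malloc(padded * sizeof(int32_t));
  if (dst == NULL) return NULL;
  for (int c = 0; c < padded; ++c) {
    dst[c] = src[c < oc ? c : oc - 1];  // repeat the last real channel
  }
  return dst;
}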
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h
index 45332b4311..ff2b9f39b0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h
@@ -28,8 +28,8 @@
 typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
                           size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
index 47149e6653..14d6309f17 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
@@ -22,11 +22,11 @@
 extern "C" {
 extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                      size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                      const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                     size_t out_multiplier, size_t shift_before, size_t shift_after);
+                                     int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                     size_t asymmetric, size_t per_channel);
 extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias);
-
 #ifdef __cplusplus
 }
 #endif
@@ -35,9 +35,10 @@
 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                        size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                        const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                       size_t out_multiplier, size_t shift_before, size_t shift_after) {
+                                       int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                       size_t asymmetric, size_t per_channel) {
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
-                                  act_max, out_zp, out_multiplier, shift_before, shift_after);
+                                  act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 }
 
 void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
index 834634e53b..f0ad7cbaaf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c
@@ -879,8 +879,8 @@
         const float *src_ptr = src_batch + hw * channel + c;
         float *dst_ptr = dst_batch + c * plane + hw;
 #ifdef ENABLE_ARM64
-        int srcStride = channel * 4;
-        int dstStride = plane * 4;
+        size_t srcStride = channel * sizeof(float);
+        size_t dstStride = plane * sizeof(float);
         asm volatile(
           "mov x10, %[src_ptr]\n"
           "mov x11, %[dst_ptr]\n"
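The pack.c change at the end is type hygiene rather than a behavior change: the strides are byte strides consumed by 64-bit address arithmetic in the inline assembly, so spelling them as size_t with sizeof(float) avoids mixing a signed 32-bit int into pointer math and documents the unit. A scalar reference for the 4x4 tile the ld1/st1 sequence transposes, matching the src_ptr/dst_ptr formulas above:

#include <stddef.h>

// One 4x4 tile of PackNHWCToNCHWFp32: src is hw-major (row stride of
// channel floats), dst is channel-major (row stride of plane floats).
static void TransposeTile4x4(const float *src_ptr, float *dst_ptr,
                             size_t channel, size_t plane) {
  for (size_t r = 0; r < 4; ++r) {      // hw offset within the tile
    for (size_t c = 0; c < 4; ++c) {    // channel offset within the tile
      dst_ptr[c * plane + r] = src_ptr[r * channel + c];
    }
  }
}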