@@ -8,8 +8,8 @@
 #endif
 // void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-//                           size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-//                           size_t shift_before, size_t shift_after);
+//                           size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
+//                           int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 IndirectGemmInt8_4x4:
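Note on the signature change: out_multiplier, shift_before and shift_after become int32_t pointers, and two flags (asymmetric, per_channel) are appended, so the same entry point can requantize either per-tensor or per-channel. A rough caller-side model (a sketch, not the committed code; the indexing is inferred from how the pointers advance 16 bytes per 4-channel block below):

    /* Sketch: with per_channel set, each 4-channel output block uses its own
     * entries; otherwise entry 0 is broadcast to every channel. */
    int32_t mult = per_channel ? out_multiplier[4 * block + i] : out_multiplier[0];
    int32_t lsh  = per_channel ? shift_before[4 * block + i]  : shift_before[0];
    int32_t rsh  = per_channel ? shift_after[4 * block + i]   : shift_after[0];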
@@ -36,18 +36,26 @@ IndirectGemmInt8_4x4:
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should be also preserved
 // whereas our coding style do not permit such amount of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16
     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23
     mul x5, x4, x5
     mov x4, #1
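The prologue grows from 144 to 176 bytes because x21-x24 (callee-saved under AAPCS64) are now spilled as well, and the widened pointer arguments are reloaded with 64-bit ldr x instead of 32-bit ldr w. The stack arguments (the 9th onward, in 8-byte slots per AAPCS64) appear to map as follows; this table is inferred from the new signature and the register uses later in the kernel:

    /* [sp, #0]  -> x15  input_sum
       [sp, #8]  -> w8   act_min
       [sp, #16] -> w9   act_max
       [sp, #24] -> w16  out_zp
       [sp, #32] -> x17  out_multiplier
       [sp, #40] -> x18  shift_before
       [sp, #48] -> x19  shift_after
       [sp, #56] -> x20  asymmetric
       [sp, #64] -> x21  per_channel */

The add/mov/sdiv triple computes x23 = (oc + 3) / 4, i.e. UP_DIV(oc, 4), which the per-channel path below apparently uses as the post-increment when walking input_sum.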
@@ -189,12 +197,6 @@ IndirectGemmInt8_4x4:
     sadalp v30.4s, v14.8h
     sadalp v31.4s, v15.8h
-    // load sum
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20]
     // pairwise add
     addp v16.4s, v16.4s, v17.4s
     addp v18.4s, v18.4s, v19.4s
@@ -212,28 +214,51 @@ IndirectGemmInt8_4x4:
     addp v20.4s, v20.4s, v22.4s
     addp v24.4s, v24.4s, v26.4s
     addp v28.4s, v28.4s, v30.4s
+    cbz x20, NoSum
+    // load sum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1r {v8.4s}, [x22], x23
+    ld1r {v9.4s}, [x22], x23
+    ld1r {v10.4s}, [x22], x23
+    ld1r {v11.4s}, [x22]
+    b AddSum
+SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22]
+AddSum:
     sub v16.4s, v16.4s, v8.4s
     sub v20.4s, v20.4s, v9.4s
     sub v24.4s, v24.4s, v10.4s
     sub v28.4s, v28.4s, v11.4s
+NoSum:
     add v16.4s, v16.4s, v12.4s
     add v20.4s, v20.4s, v12.4s
     add v24.4s, v24.4s, v12.4s
     add v28.4s, v28.4s, v12.4s
-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+QuantizeStart:
     sqshl v16.4s, v16.4s, v2.4s
     sqshl v20.4s, v20.4s, v2.4s
     sqshl v24.4s, v24.4s, v2.4s
     sqshl v28.4s, v28.4s, v2.4s
-    dup v3.4s, w17
     sqrdmulh v16.4s, v16.4s, v3.4s
     sqrdmulh v20.4s, v20.4s, v3.4s
     sqrdmulh v24.4s, v24.4s, v3.4s
     sqrdmulh v28.4s, v28.4s, v3.4s
-    dup v4.4s, w19
     and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
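The requantize sequence reads left to right: sqshl applies shift_before, sqrdmulh applies out_multiplier as a rounded doubling multiply-high, and the and/sshr/sqadd/srshl quartet performs a sign-corrected rounding right shift by shift_after. A scalar model of one lane (a sketch, assuming neg_shift is the non-positive value fed to srshl; it ignores saturation and the negative-value rounding fixup, which is shown in intrinsic form after the 24x4 hunks below):

    #include <stdint.h>
    static int32_t requantize_lane(int32_t acc, int32_t multiplier,
                                   int32_t left_shift, int32_t neg_shift) {
      int64_t prod = ((int64_t)acc << left_shift) * multiplier; /* sqshl + sqrdmulh */
      int32_t x = (int32_t)((prod + (1ll << 30)) >> 31);        /* rounded high half */
      int shift = -neg_shift;                                   /* srshl by a negative amount */
      int32_t round = (shift > 0) ? (1 << (shift - 1)) : 0;
      return (int32_t)(((int64_t)x + round) >> shift);
    }

Per the remaining parameter names, the kernel presumably then adds out_zp and clamps to [act_min, act_max] before narrowing to int8; that part is not visible in this hunk.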
@@ -325,15 +350,25 @@ IndirectGemmInt8_4x4:
     bne LoopKsize
     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+NoChannelForward:
     cbz x3, NoStepFowrard
     add x3, x3, #16
 NoStepFowrard:
     bgt LoopOc
-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif
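The new block between subs and cbz x3 steps the quantization pointers once per group of four output channels. In C terms (a sketch; the variable names mirror the registers, not the committed sources):

    /* x21 = per_channel, x20 = asymmetric; each step covers 4 channels (16 bytes). */
    if (per_channel) {
      if (asymmetric) input_sum += 4; /* add x15, x15, #16 */
      out_multiplier += 4;            /* add x17, x17, #16 */
      shift_before += 4;              /* add x18, x18, #16 */
      shift_after += 4;               /* add x19, x19, #16 */
    }

The same stepping is added to the 24x4 dotprod kernel at the end of its LoopOc (last hunk of that file below).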
@@ -8,8 +8,8 @@
 #endif
 // void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-//                               size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-//                               size_t shift_before, size_t shift_after);
+//                               size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
+//                               int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 // we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later)
 // mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A)
@@ -17,35 +17,64 @@
 IndirectGemmInt8_24x4_dp:
 .macro INIT_BIAS
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20], #4
-    ld1r {v12.4s}, [x20], #4
-    ld1r {v13.4s}, [x20], #4
-    ld1r {v14.4s}, [x20], #4
-    ld1r {v15.4s}, [x20], #4
-    ld1r {v16.4s}, [x20], #4
-    ld1r {v17.4s}, [x20], #4
-    ld1r {v18.4s}, [x20], #4
-    ld1r {v19.4s}, [x20], #4
-    ld1r {v20.4s}, [x20], #4
-    ld1r {v21.4s}, [x20], #4
-    ld1r {v22.4s}, [x20], #4
-    ld1r {v23.4s}, [x20], #4
-    ld1r {v24.4s}, [x20], #4
-    ld1r {v25.4s}, [x20], #4
-    ld1r {v26.4s}, [x20], #4
-    ld1r {v27.4s}, [x20], #4
-    ld1r {v28.4s}, [x20], #4
-    ld1r {v29.4s}, [x20], #4
-    ld1r {v30.4s}, [x20], #4
-    ld1r {v31.4s}, [x20], #4
     dup v7.4s, wzr
     cbz x3, InitBias
     ld1 {v7.4s}, [x3]
 InitBias:
+    cbz x20, NoSum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1 {v8.4s}, [x22], x23
+    ld1 {v9.4s}, [x22], x23
+    ld1 {v10.4s}, [x22], x23
+    ld1 {v11.4s}, [x22], x23
+    ld1 {v12.4s}, [x22], x23
+    ld1 {v13.4s}, [x22], x23
+    ld1 {v14.4s}, [x22], x23
+    ld1 {v15.4s}, [x22], x23
+    ld1 {v16.4s}, [x22], x23
+    ld1 {v17.4s}, [x22], x23
+    ld1 {v18.4s}, [x22], x23
+    ld1 {v19.4s}, [x22], x23
+    ld1 {v20.4s}, [x22], x23
+    ld1 {v21.4s}, [x22], x23
+    ld1 {v22.4s}, [x22], x23
+    ld1 {v23.4s}, [x22], x23
+    ld1 {v24.4s}, [x22], x23
+    ld1 {v25.4s}, [x22], x23
+    ld1 {v26.4s}, [x22], x23
+    ld1 {v27.4s}, [x22], x23
+    ld1 {v28.4s}, [x22], x23
+    ld1 {v29.4s}, [x22], x23
+    ld1 {v30.4s}, [x22], x23
+    ld1 {v31.4s}, [x22], x23
+    b AddSum
+SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22], #4
+    ld1r {v12.4s}, [x22], #4
+    ld1r {v13.4s}, [x22], #4
+    ld1r {v14.4s}, [x22], #4
+    ld1r {v15.4s}, [x22], #4
+    ld1r {v16.4s}, [x22], #4
+    ld1r {v17.4s}, [x22], #4
+    ld1r {v18.4s}, [x22], #4
+    ld1r {v19.4s}, [x22], #4
+    ld1r {v20.4s}, [x22], #4
+    ld1r {v21.4s}, [x22], #4
+    ld1r {v22.4s}, [x22], #4
+    ld1r {v23.4s}, [x22], #4
+    ld1r {v24.4s}, [x22], #4
+    ld1r {v25.4s}, [x22], #4
+    ld1r {v26.4s}, [x22], #4
+    ld1r {v27.4s}, [x22], #4
+    ld1r {v28.4s}, [x22], #4
+    ld1r {v29.4s}, [x22], #4
+    ld1r {v30.4s}, [x22], #4
+    ld1r {v31.4s}, [x22], #4
+AddSum:
     sub v8.4s, v7.4s, v8.4s
     sub v9.4s, v7.4s, v9.4s
     sub v10.4s, v7.4s, v10.4s
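INIT_BIAS previously loaded the 24 broadcast input sums unconditionally; it now branches on x20 (asymmetric) and x21 (per_channel) before initializing the 24 accumulators, all of which cover the same four output channels (one accumulator per output pixel). A compact model (a sketch; load_sum_vec is a hypothetical helper standing in for the ld1/ld1r variants above):

    /* v7 = bias (or zero); acc[0..23] model v8..v31. */
    int32x4_t b = bias ? vld1q_s32(bias) : vdupq_n_s32(0);
    for (int i = 0; i < 24; ++i) {
      acc[i] = asymmetric ? vsubq_s32(b, load_sum_vec(input_sum, i, per_channel))
                          : b; /* NoSum path: plain bias */
    }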
@@ -70,24 +99,59 @@ IndirectGemmInt8_24x4_dp:
     sub v29.4s, v7.4s, v29.4s
     sub v30.4s, v7.4s, v30.4s
     sub v31.4s, v7.4s, v31.4s
+    b InitBiasEnd
+NoSum:
+    mov v8.16b, v7.16b
+    mov v9.16b, v7.16b
+    mov v10.16b, v7.16b
+    mov v11.16b, v7.16b
+    mov v12.16b, v7.16b
+    mov v13.16b, v7.16b
+    mov v14.16b, v7.16b
+    mov v15.16b, v7.16b
+    mov v16.16b, v7.16b
+    mov v17.16b, v7.16b
+    mov v18.16b, v7.16b
+    mov v19.16b, v7.16b
+    mov v20.16b, v7.16b
+    mov v21.16b, v7.16b
+    mov v22.16b, v7.16b
+    mov v23.16b, v7.16b
+    mov v24.16b, v7.16b
+    mov v25.16b, v7.16b
+    mov v26.16b, v7.16b
+    mov v27.16b, v7.16b
+    mov v28.16b, v7.16b
+    mov v29.16b, v7.16b
+    mov v30.16b, v7.16b
+    mov v31.16b, v7.16b
+InitBiasEnd:
 .endm
 // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should be also preserved
 // whereas our coding style do not permit such amount of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16
     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23
     mul x5, x4, x5
     mov x4, #1
@@ -206,7 +270,7 @@ IndirectGemmInt8_24x4_dp:
     b LoopIc
 LoopIcEnd:
-    mov x20, x15
+    mov x22, x15
     // load input for output 1-8
     ld1 {v0.16b, v1.16b}, [x12], #32
     .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
@@ -276,7 +340,16 @@ IndirectGemmInt8_24x4_dp:
     .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
 Quantization:
-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+QuantizeStart:
     sqshl v8.4s, v8.4s, v2.4s
     sqshl v9.4s, v9.4s, v2.4s
     sqshl v10.4s, v10.4s, v2.4s
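Per-tensor and per-channel differ only in the addressing mode: ld1r broadcasts one scalar to all four lanes, ld1 loads four distinct values. An intrinsics equivalent (a sketch):

    /* out_multiplier points at either 1 entry (per-tensor) or >= 4 (per-channel). */
    int32x4_t mult = per_channel ? vld1q_s32(out_multiplier)      /* ld1  {v3.4s}, [x17] */
                                 : vld1q_dup_s32(out_multiplier); /* ld1r {v3.4s}, [x17] */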
@@ -302,7 +375,6 @@ IndirectGemmInt8_24x4_dp:
     sqshl v30.4s, v30.4s, v2.4s
     sqshl v31.4s, v31.4s, v2.4s
-    dup v3.4s, w17
     sqrdmulh v8.4s, v8.4s, v3.4s
     sqrdmulh v9.4s, v9.4s, v3.4s
     sqrdmulh v10.4s, v10.4s, v3.4s
@@ -328,100 +400,99 @@ IndirectGemmInt8_24x4_dp:
     sqrdmulh v30.4s, v30.4s, v3.4s
     sqrdmulh v31.4s, v31.4s, v3.4s
-    dup v4.4s, w19
-    add v0.16b, v4.16b, v8.16b
+    and v0.16b, v4.16b, v8.16b
     sshr v0.4s, v0.4s, #31
     sqadd v8.4s, v8.4s, v0.4s
     srshl v8.4s, v8.4s, v4.4s
-    add v0.16b, v4.16b, v9.16b
+    and v0.16b, v4.16b, v9.16b
     sshr v1.4s, v1.4s, #31
     sqadd v9.4s, v9.4s, v1.4s
     srshl v9.4s, v9.4s, v4.4s
-    add v2.16b, v4.16b, v10.16b
+    and v2.16b, v4.16b, v10.16b
     sshr v2.4s, v2.4s, #31
     sqadd v10.4s, v10.4s, v2.4s
     srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v11.16b
+    and v3.16b, v4.16b, v11.16b
     sshr v3.4s, v3.4s, #31
     sqadd v11.4s, v11.4s, v3.4s
     srshl v11.4s, v11.4s, v4.4s
-    add v0.16b, v4.16b, v12.16b
+    and v0.16b, v4.16b, v12.16b
     sshr v0.4s, v0.4s, #31
     sqadd v12.4s, v12.4s, v0.4s
     srshl v12.4s, v12.4s, v4.4s
-    add v0.16b, v4.16b, v13.16b
+    and v0.16b, v4.16b, v13.16b
     sshr v1.4s, v1.4s, #31
     sqadd v13.4s, v13.4s, v1.4s
     srshl v13.4s, v13.4s, v4.4s
-    add v2.16b, v4.16b, v14.16b
+    and v2.16b, v4.16b, v14.16b
     sshr v2.4s, v2.4s, #31
     sqadd v14.4s, v14.4s, v2.4s
     srshl v14.4s, v14.4s, v4.4s
-    add v3.16b, v4.16b, v15.16b
+    and v3.16b, v4.16b, v15.16b
     sshr v3.4s, v3.4s, #31
     sqadd v15.4s, v15.4s, v3.4s
     srshl v15.4s, v15.4s, v4.4s
-    add v0.16b, v4.16b, v16.16b
+    and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
     srshl v16.4s, v16.4s, v4.4s
-    add v0.16b, v4.16b, v17.16b
+    and v0.16b, v4.16b, v17.16b
     sshr v1.4s, v1.4s, #31
     sqadd v17.4s, v17.4s, v1.4s
     srshl v17.4s, v17.4s, v4.4s
-    add v2.16b, v4.16b, v18.16b
+    and v2.16b, v4.16b, v18.16b
     sshr v2.4s, v2.4s, #31
     sqadd v18.4s, v18.4s, v2.4s
     srshl v18.4s, v18.4s, v4.4s
-    add v3.16b, v4.16b, v19.16b
+    and v3.16b, v4.16b, v19.16b
     sshr v3.4s, v3.4s, #31
     sqadd v19.4s, v19.4s, v3.4s
     srshl v19.4s, v19.4s, v4.4s
-    add v0.16b, v4.16b, v20.16b
+    and v0.16b, v4.16b, v20.16b
     sshr v0.4s, v0.4s, #31
     sqadd v20.4s, v20.4s, v0.4s
     srshl v20.4s, v20.4s, v4.4s
-    add v0.16b, v4.16b, v21.16b
+    and v0.16b, v4.16b, v21.16b
     sshr v1.4s, v1.4s, #31
     sqadd v21.4s, v21.4s, v1.4s
     srshl v21.4s, v21.4s, v4.4s
-    add v2.16b, v4.16b, v22.16b
+    and v2.16b, v4.16b, v22.16b
     sshr v2.4s, v2.4s, #31
     sqadd v22.4s, v22.4s, v2.4s
-    srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v23.16b
+    srshl v22.4s, v22.4s, v4.4s
+    and v3.16b, v4.16b, v23.16b
     sshr v3.4s, v3.4s, #31
     sqadd v23.4s, v23.4s, v3.4s
     srshl v23.4s, v23.4s, v4.4s
-    add v0.16b, v4.16b, v24.16b
+    and v0.16b, v4.16b, v24.16b
     sshr v0.4s, v0.4s, #31
     sqadd v24.4s, v24.4s, v0.4s
     srshl v24.4s, v24.4s, v4.4s
-    add v0.16b, v4.16b, v25.16b
+    and v0.16b, v4.16b, v25.16b
     sshr v1.4s, v1.4s, #31
     sqadd v25.4s, v25.4s, v1.4s
     srshl v25.4s, v25.4s, v4.4s
-    add v2.16b, v4.16b, v26.16b
+    and v2.16b, v4.16b, v26.16b
     sshr v2.4s, v2.4s, #31
     sqadd v26.4s, v26.4s, v2.4s
     srshl v26.4s, v26.4s, v4.4s
-    add v3.16b, v4.16b, v27.16b
+    and v3.16b, v4.16b, v27.16b
     sshr v3.4s, v3.4s, #31
     sqadd v27.4s, v27.4s, v3.4s
     srshl v27.4s, v27.4s, v4.4s
-    add v0.16b, v4.16b, v28.16b
+    and v0.16b, v4.16b, v28.16b
     sshr v0.4s, v0.4s, #31
     sqadd v28.4s, v28.4s, v0.4s
     srshl v28.4s, v28.4s, v4.4s
-    add v0.16b, v4.16b, v29.16b
+    and v0.16b, v4.16b, v29.16b
     sshr v1.4s, v1.4s, #31
     sqadd v29.4s, v29.4s, v1.4s
     srshl v29.4s, v29.4s, v4.4s
-    add v2.16b, v4.16b, v30.16b
+    and v2.16b, v4.16b, v30.16b
     sshr v2.4s, v2.4s, #31
     sqadd v30.4s, v30.4s, v2.4s
     srshl v30.4s, v30.4s, v4.4s
-    add v3.16b, v4.16b, v31.16b
+    and v3.16b, v4.16b, v31.16b
     sshr v3.4s, v3.4s, #31
     sqadd v31.4s, v31.4s, v3.4s
     srshl v31.4s, v31.4s, v4.4s
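Besides dropping the dup of w19 (v4 is now loaded in the Quantization prologue), this hunk fixes two genuine bugs: every add v*, v4.16b, ... in the sign-fixup should have been and (the fixup wants the sign of value AND negative-shift, not a sum), and one stray srshl wrote v10 where v22 was meant. The corrected quartet is the standard NEON rounding-divide-by-power-of-two idiom, as used for example in gemmlowp; an intrinsics rendering (a sketch):

    #include <arm_neon.h>
    /* neg_shift holds -exponent in every lane (the value fed to srshl). */
    static inline int32x4_t rounding_divide_by_pot(int32x4_t x, int32x4_t neg_shift) {
      int32x4_t fixup = vshrq_n_s32(vandq_s32(x, neg_shift), 31); /* and + sshr #31 */
      int32x4_t fixed = vqaddq_s32(x, fixup);                     /* sqadd */
      return vrshlq_s32(fixed, neg_shift);                        /* srshl */
    }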
@@ -694,15 +765,24 @@ IndirectGemmInt8_24x4_dp:
     bne LoopKsize
     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+NoChannelForward:
     cbz x3, NoStepFowrard
     add x3, x3, #16
 NoStepFowrard:
     bgt LoopOc
-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif
@@ -16,7 +16,7 @@
 #include "nnacl/fp32/common_func.h"
-#ifndef __aarch64__
+#ifndef ENABLE_ARM64
 void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
                size_t row, size_t col) {
   for (int r = 0; r < row; r++) {
@@ -40,8 +40,8 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
                                size_t oc4, size_t offset);
 void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
                           size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);
 void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                         size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
@@ -29,14 +29,12 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
   int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
   int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
   int oc4 = UP_DIV(output_channel, C4NUM);
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
+  size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+  size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
   IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
                        output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-                       shift_before, shift_after);
-  // #elif defined(ENABLE_ARM32)
-  // IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
-  //                      output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-  //                      shift_before, shift_after);
+                       shift_before, shift_after, asymmetric, per_channel);
 #else
   int tile_num = conv_param->tile_num_;
   int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@@ -124,8 +122,10 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
   int oc4 = UP_DIV(output_channel, C4NUM);
   if (gemm_func != NULL) {
 #ifdef __aarch64__
+    size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+    size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
     gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum,
-              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
+              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 #endif
   } else {
     int tile_num = conv_param->tile_num_;
@@ -28,8 +28,8 @@
 typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
                           size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);
 #ifdef __cplusplus
 extern "C" {
@@ -22,11 +22,11 @@ extern "C" {
 extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                      size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                      const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                     size_t out_multiplier, size_t shift_before, size_t shift_after);
+                                     int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                     size_t asymmetric, size_t per_channel);
 extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias);
 #ifdef __cplusplus
 }
 #endif
@@ -35,9 +35,10 @@ extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, in
 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                        size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                        const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                       size_t out_multiplier, size_t shift_before, size_t shift_after) {
+                                       int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                       size_t asymmetric, size_t per_channel) {
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
-                                  act_max, out_zp, out_multiplier, shift_before, shift_after);
+                                  act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 }
 void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
@@ -879,8 +879,8 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
       const float *src_ptr = src_batch + hw * channel + c;
       float *dst_ptr = dst_batch + c * plane + hw;
 #ifdef ENABLE_ARM64
-      int srcStride = channel * 4;
-      int dstStride = plane * 4;
+      size_t srcStride = channel * sizeof(float);
+      size_t dstStride = plane * sizeof(float);
       asm volatile(
         "mov x10, %[src_ptr]\n"
         "mov x11, %[dst_ptr]\n"