Merge pull request !7715 from yangruoqi713/conv_dw (tags/v1.1.0)
| @@ -1,168 +0,0 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDw3x3BorderPixelInt8 | |||
| #ifndef __APPLE__ | |||
| .type ConvDw3x3BorderPixelInt8, %function | |||
| #endif | |||
| // void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, | |||
| // size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, | |||
| // size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) { | |||
| // Accumulates a height x width window of (src - in_zp) * weight on top of bias for 8 channels at a time, | |||
| // then scales, offsets by out_zp, clamps to [acc_min, acc_max] and stores the result as int8. | |||
| // x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, | |||
| // stack: [sp] channel, [sp, #8] in_zp, [sp, #16] out_zp, [sp, #24] out_multiplier, [sp, #32] left_shift, | |||
| // [sp, #40] right_shift, [sp, #48] acc_min, [sp, #56] acc_max | |||
| ConvDw3x3BorderPixelInt8: | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ x29 must also be preserved and x18 is reserved as the platform register. | |||
| // Only x0 ~ x17, v0 ~ v7 and v16 ~ v31 are written below, so nothing needs to be saved. | |||
| ldr x8, [sp] | |||
| ldrb w9, [sp, #8] | |||
| dup v25.8b, w9 // in_zp | |||
| ldr x9, [sp, #16] | |||
| dup v26.4s, w9 // out_zp | |||
| ldr x9, [sp, #24] | |||
| dup v27.4s, w9 // out_multiplier | |||
| ldr x9, [sp, #32] | |||
| dup v28.4s, w9 // left_shift | |||
| ldr x9, [sp, #40] | |||
| dup v29.4s, w9 // right_shift | |||
| ldr x9, [sp, #48] | |||
| dup v30.4s, w9 // acc_min | |||
| ldr x9, [sp, #56] | |||
| dup v31.4s, w9 // acc_max | |||
| mov x9, #2 | |||
| mul x13, x8, x9 // x13 = channel * 2: byte step to the next int16 weight column | |||
| mov x9, #3 | |||
| mul x14, x13, x9 // x14 = channel * 3 * 2: byte step to the next weight row | |||
| LoopC: | |||
| // v23/v24 start from the bias of the current 8 channels and act as the int32 accumulators | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| cmp x4, #2 | |||
| blt LoopHW | |||
| LoopH2W2: | |||
| // fast path: the first 2 x 2 taps are unrolled; a third column or a third row is added afterwards | |||
| cmp x5, #2 | |||
| blt LoopHW | |||
| ld1 {v0.8b}, [x9], x7 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| add x11, x1, x6 | |||
| ld1 {v4.8h}, [x10], x13 // weight | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x7 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| ld1 {v5.8h}, [x10], x13 | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| add x15, x11, x6 | |||
| ld1 {v2.8b}, [x11], x7 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| add x16, x12, x14 | |||
| ld1 {v6.8h}, [x12], x13 | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| ld1 {v3.8b}, [x11], x7 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| // NOTE(review): width == 3 is tested first, so a 3 x 3 window would only get its first two rows; | |||
| // presumably callers only pass border windows (at most one dimension equal to 3) - confirm. | |||
| cmp x5, #3 | |||
| beq LoopH2W3 | |||
| cmp x4, #3 | |||
| beq LoopH3W2 | |||
| b Post | |||
| LoopH2W3: | |||
| // third column of rows 0 and 1 | |||
| ld1 {v16.8b}, [x9], x7 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v17.8h}, [x10], x13 | |||
| smlal v23.4s, v16.4h, v17.4h | |||
| smlal2 v24.4s, v16.8h, v17.8h | |||
| ld1 {v18.8b}, [x11], x7 | |||
| ssubl v18.8h, v18.8b, v25.8b | |||
| ld1 {v19.8h}, [x12], x13 | |||
| smlal v23.4s, v18.4h, v19.4h | |||
| smlal2 v24.4s, v18.8h, v19.8h | |||
| b Post | |||
| LoopH3W2: | |||
| // both columns of row 2 | |||
| ld1 {v16.8b}, [x15], x7 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v17.8h}, [x16], x13 | |||
| smlal v23.4s, v16.4h, v17.4h | |||
| smlal2 v24.4s, v16.8h, v17.8h | |||
| ld1 {v18.8b}, [x15], x7 | |||
| ssubl v18.8h, v18.8b, v25.8b | |||
| ld1 {v19.8h}, [x16], x13 | |||
| smlal v23.4s, v18.4h, v19.4h | |||
| smlal2 v24.4s, v18.8h, v19.8h | |||
| b Post | |||
| LoopHW: | |||
| // generic path, taken when height < 2 or width < 2; it is only entered before anything was accumulated | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| mov x17, x4 // height | |||
| LoopH: | |||
| mov x11, x9 | |||
| mov x12, x10 | |||
| mov x16, x5 // width (x16 is free on this path; x18 is the platform register and must not be used) | |||
| LoopW: | |||
| ld1 {v0.8b}, [x11], x7 | |||
| ssubl v1.8h, v0.8b, v25.8b | |||
| ld1 {v2.8h}, [x12], x13 // weight | |||
| smlal v23.4s, v1.4h, v2.4h | |||
| smlal2 v24.4s, v1.8h, v2.8h | |||
| subs x16, x16, #1 | |||
| bne LoopW | |||
| subs x17, x17, #1 | |||
| add x9, x9, x6 | |||
| add x10, x10, x14 | |||
| bne LoopH | |||
| Post: | |||
| // requantize: left shift, fixed-point multiply, rounding right shift, add out_zp, clamp, narrow to int8 | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| // v21/v22 are caller-saved scratch registers (v8 ~ v15 must not be clobbered) | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| // advance to the next 8 channels: 8 int8 inputs, 8 int16 weights | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| sub x8, x8, #8 | |||
| cmp x8, #8 | |||
| bge LoopC | |||
| ret | |||
| #endif | |||
| @@ -0,0 +1,168 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDw3x3Int8Corner | |||
| #ifndef __APPLE__ | |||
| .type ConvDw3x3Int8Corner, %function | |||
| #endif | |||
| // void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, | |||
| // size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, | |||
| // size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) | |||
| // Per 8 channels: bias + sum over a 2 x 2 window of (src - in_zp) * weight, scaled, offset by out_zp, clamped, stored as int8. | |||
| // x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp, | |||
| // stack: [sp] out_zp, [sp, #8] out_multiplier, [sp, #16] left_shift, [sp, #24] right_shift, [sp, #32] acc_min, [sp, #40] acc_max | |||
| ConvDw3x3Int8Corner: | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ x29 must also be preserved; none of these registers is written below, so nothing is saved | |||
| // (note: our coding style does not normally permit this many parameters) | |||
| dup v25.8b, w7 // in_zp | |||
| ldr x9, [sp] | |||
| dup v26.4s, w9 // out_zp | |||
| ldr x9, [sp, #8] | |||
| dup v27.4s, w9 // out_multiplier | |||
| ldr x9, [sp, #16] | |||
| dup v28.4s, w9 // left_shift | |||
| ldr x9, [sp, #24] | |||
| dup v29.4s, w9 // right_shift | |||
| ldr x9, [sp, #32] | |||
| dup v30.4s, w9 // acc_min | |||
| ldr x9, [sp, #40] | |||
| dup v31.4s, w9 // acc_max | |||
| mov x9, #2 | |||
| mul x13, x6, x9 // x13 = channel * 2: byte step to the next int16 weight column | |||
| mov x9, #3 | |||
| mul x14, x13, x9 // x14 = channel * 3 * 2: byte step to the next weight row | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| add x11, x1, x4 | |||
| ld1 {v4.8h}, [x10], x13 // weight; inputs are widened to int16 as (src - in_zp) before the multiply | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| ld1 {v5.8h}, [x10], x13 | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| ld1 {v6.8h}, [x12], x13 | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| cmp x6, #8 | |||
| ble LoopC8Post | |||
| LoopC8: | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| ld1 {v4.8h}, [x10], x13 // weight of the next 8 channels, fetched while the current 8 are still being accumulated | |||
| add x11, x1, x4 | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| ld1 {v5.8h}, [x10], x13 | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| ld1 {v6.8h}, [x12], x13 | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| sub x6, x6, #8 | |||
| cmp x6, #8 | |||
| bgt LoopC8 | |||
| LoopC8Post: | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ret | |||
| #endif | |||
| @@ -0,0 +1,196 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDw3x3Int8Horizontal | |||
| #ifndef __APPLE__ | |||
| .type ConvDw3x3Int8Horizontal, %function | |||
| #endif | |||
| // void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, | |||
| // size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, | |||
| // size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) | |||
| // Per 8 channels: bias + sum over a 3-row x 2-column window of (src - in_zp) * weight, scaled, offset by out_zp, clamped, stored as int8. | |||
| // x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp, | |||
| // stack: [sp] out_zp, [sp, #8] out_multiplier, [sp, #16] left_shift, [sp, #24] right_shift, [sp, #32] acc_min, [sp, #40] acc_max | |||
| ConvDw3x3Int8Horizontal: | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ x29 must also be preserved; none of these registers is written below, so nothing is saved | |||
| // (note: our coding style does not normally permit this many parameters) | |||
| dup v25.8b, w7 // in_zp | |||
| ldr x9, [sp] | |||
| dup v26.4s, w9 // out_zp | |||
| ldr x9, [sp, #8] | |||
| dup v27.4s, w9 // out_multiplier | |||
| ldr x9, [sp, #16] | |||
| dup v28.4s, w9 // left_shift | |||
| ldr x9, [sp, #24] | |||
| dup v29.4s, w9 // right_shift | |||
| ldr x9, [sp, #32] | |||
| dup v30.4s, w9 // acc_min | |||
| ldr x9, [sp, #40] | |||
| dup v31.4s, w9 // acc_max | |||
| mov x9, #2 | |||
| mul x13, x6, x9 // x13 = channel * 2: byte step to the next int16 weight column | |||
| mov x9, #3 | |||
| mul x14, x13, x9 // x14 = channel * 3 * 2: byte step to the next weight row | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| add x11, x1, x4 | |||
| ld1 {v4.8h}, [x10], x13 // weight; inputs are widened to int16 as (src - in_zp) before the multiply | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| ld1 {v5.8h}, [x10], x13 | |||
| add x15, x11, x4 | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| add x16, x12, x14 | |||
| ld1 {v6.8h}, [x12], x13 | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| ld1 {v16.8b}, [x15], x5 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v18.8h}, [x16], x13 | |||
| ld1 {v17.8b}, [x15], x5 | |||
| ssubl v17.8h, v17.8b, v25.8b | |||
| ld1 {v19.8h}, [x16], x13 | |||
| cmp x6, #8 | |||
| ble LoopC8Post | |||
| LoopC8: | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| ld1 {v4.8h}, [x10], x13 // weight of the next 8 channels, fetched while the current 8 are still being accumulated | |||
| add x11, x1, x4 | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| ld1 {v5.8h}, [x10], x13 | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| add x15, x11, x4 | |||
| add x16, x12, x14 | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| ld1 {v6.8h}, [x12], x13 | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| smlal v23.4s, v16.4h, v18.4h | |||
| ld1 {v7.8h}, [x12], x13 | |||
| smlal2 v24.4s, v16.8h, v18.8h | |||
| ld1 {v16.8b}, [x15], x5 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| smlal v23.4s, v17.4h, v19.4h | |||
| ld1 {v18.8h}, [x16], x13 | |||
| smlal2 v24.4s, v17.8h, v19.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| ld1 {v17.8b}, [x15], x5 | |||
| ssubl v17.8h, v17.8b, v25.8b | |||
| ld1 {v19.8h}, [x16], x13 | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| sub x6, x6, #8 | |||
| cmp x6, #8 | |||
| bgt LoopC8 | |||
| LoopC8Post: | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| smlal v23.4s, v16.4h, v18.4h | |||
| smlal2 v24.4s, v16.8h, v18.8h | |||
| smlal v23.4s, v17.4h, v19.4h | |||
| smlal2 v24.4s, v17.8h, v19.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ret | |||
| #endif | |||
| @@ -0,0 +1,192 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDw3x3Int8Vertical | |||
| #ifndef __APPLE__ | |||
| .type ConvDw3x3Int8Vertical, %function | |||
| #endif | |||
| // void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, | |||
| // size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, | |||
| // size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) | |||
| // Per 8 channels: bias + sum over a 2-row x 3-column window of (src - in_zp) * weight, scaled, offset by out_zp, clamped, stored as int8. | |||
| // x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp, | |||
| // stack: [sp] out_zp, [sp, #8] out_multiplier, [sp, #16] left_shift, [sp, #24] right_shift, [sp, #32] acc_min, [sp, #40] acc_max | |||
| ConvDw3x3Int8Vertical: | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ x29 must also be preserved; none of these registers is written below, so nothing is saved | |||
| // (note: our coding style does not normally permit this many parameters) | |||
| dup v25.8b, w7 // in_zp | |||
| ldr x9, [sp] | |||
| dup v26.4s, w9 // out_zp | |||
| ldr x9, [sp, #8] | |||
| dup v27.4s, w9 // out_multiplier | |||
| ldr x9, [sp, #16] | |||
| dup v28.4s, w9 // left_shift | |||
| ldr x9, [sp, #24] | |||
| dup v29.4s, w9 // right_shift | |||
| ldr x9, [sp, #32] | |||
| dup v30.4s, w9 // acc_min | |||
| ldr x9, [sp, #40] | |||
| dup v31.4s, w9 // acc_max | |||
| mov x9, #2 | |||
| mul x13, x6, x9 // x13 = channel * 2: byte step to the next int16 weight column | |||
| mov x9, #3 | |||
| mul x14, x13, x9 // x14 = channel * 3 * 2: byte step to the next weight row | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| add x11, x1, x4 | |||
| ld1 {v4.8h}, [x10], x13 // weight; inputs are widened to int16 as (src - in_zp) before the multiply | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| ld1 {v5.8h}, [x10], x13 | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| ld1 {v6.8h}, [x12], x13 | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| ld1 {v16.8b}, [x9], x5 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v18.8h}, [x10], x13 | |||
| ld1 {v17.8b}, [x11], x5 | |||
| ssubl v17.8h, v17.8b, v25.8b | |||
| ld1 {v19.8h}, [x12], x13 | |||
| cmp x6, #8 | |||
| ble LoopC8Post | |||
| LoopC8: | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| ld1 {v0.8b}, [x9], x5 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| ld1 {v4.8h}, [x10], x13 // weight of the next 8 channels, fetched while the current 8 are still being accumulated | |||
| add x11, x1, x4 | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x5 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| ld1 {v5.8h}, [x10], x13 | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| ld1 {v2.8b}, [x11], x5 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| ld1 {v6.8h}, [x12], x13 | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| ld1 {v3.8b}, [x11], x5 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| smlal v23.4s, v16.4h, v18.4h | |||
| ld1 {v7.8h}, [x12], x13 | |||
| smlal2 v24.4s, v16.8h, v18.8h | |||
| ld1 {v16.8b}, [x9], x5 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| smlal v23.4s, v17.4h, v19.4h | |||
| ld1 {v18.8h}, [x10], x13 | |||
| smlal2 v24.4s, v17.8h, v19.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| ld1 {v17.8b}, [x11], x5 | |||
| ssubl v17.8h, v17.8b, v25.8b | |||
| ld1 {v19.8h}, [x12], x13 | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| sub x6, x6, #8 | |||
| cmp x6, #8 | |||
| bgt LoopC8 | |||
| LoopC8Post: | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| smlal v23.4s, v16.4h, v18.4h | |||
| smlal2 v24.4s, v16.8h, v18.8h | |||
| smlal v23.4s, v17.4h, v19.4h | |||
| smlal2 v24.4s, v17.8h, v19.8h | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v21.16b, v29.16b, v23.16b | |||
| sshr v21.4s, v21.4s, #31 | |||
| sqadd v23.4s, v23.4s, v21.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v22.16b, v29.16b, v24.16b | |||
| sshr v22.4s, v22.4s, #31 | |||
| sqadd v24.4s, v24.4s, v22.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| ret | |||
| #endif | |||
| @@ -47,10 +47,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con | |||
| size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, | |||
| int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *acc_min, int32_t *acc_max); | |||
| void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, | |||
| size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, | |||
| size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift, | |||
| size_t acc_min, size_t acc_max); | |||
| #endif | |||
| #ifdef ENABLE_ARM32 | |||
| @@ -71,6 +67,21 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei | |||
| void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); | |||
| void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, | |||
| int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, | |||
| int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min, | |||
| int32_t acc_max); | |||
| void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, | |||
| size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, | |||
| size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max); | |||
| void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, | |||
| size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, | |||
| size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, | |||
| size_t acc_max); | |||
| void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, | |||
| size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, | |||
| size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, | |||
| size_t acc_max); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -232,29 +232,31 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const | |||
| int ih_offset = 64 * block_input_w; | |||
| int w = start_w; | |||
| for (; w <= end_w - block_output_w; w += block_output_w) { | |||
| int8_t *output_ptr = output; | |||
| const int8_t *input_ptr = input; | |||
| const int16_t *weight_ptr = weight; | |||
| const int32_t *bias_ptr = bias; | |||
| int c = 0; | |||
| for (; c <= conv_param->output_channel_ - 64; c += 64) { | |||
| InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); | |||
| ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, | |||
| block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift, | |||
| acc_min, acc_max, conv_param->stride_h_); | |||
| output_ptr += 64; | |||
| input_ptr += 64; | |||
| weight_ptr += 64; | |||
| bias_ptr += 64; | |||
| if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { | |||
| for (; w <= end_w - block_output_w; w += block_output_w) { | |||
| int8_t *output_ptr = output; | |||
| const int8_t *input_ptr = input; | |||
| const int16_t *weight_ptr = weight; | |||
| const int32_t *bias_ptr = bias; | |||
| int c = 0; | |||
| for (; c <= conv_param->output_channel_ - 64; c += 64) { | |||
| InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); | |||
| ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, | |||
| block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift, | |||
| acc_min, acc_max, conv_param->stride_h_); | |||
| output_ptr += 64; | |||
| input_ptr += 64; | |||
| weight_ptr += 64; | |||
| bias_ptr += 64; | |||
| } | |||
| // left channel | |||
| ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, | |||
| conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, | |||
| conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier, | |||
| left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_); | |||
| output += block_output_w * conv_param->input_channel_; | |||
| input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; | |||
| } | |||
| // left channel | |||
| ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, | |||
| conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, | |||
| conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier, | |||
| left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_); | |||
| output += block_output_w * conv_param->input_channel_; | |||
| input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; | |||
| } | |||
| // left width | |||
| int left_width = end_w - w; | |||
| @@ -300,8 +302,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data | |||
| } | |||
| } | |||
| #ifndef ENABLE_ARM | |||
| void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, | |||
| void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, | |||
| int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, | |||
| int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| for (int c = 0; c < channel; c += 8) { | |||
| @@ -337,9 +338,30 @@ void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *wei | |||
| } | |||
| } | |||
| } | |||
| #ifndef ENABLE_ARM64 | |||
| // Fallback used when ENABLE_ARM64 is not defined: handles a border pixel covered by a 2 x 2 kernel window by | |||
| // forwarding to the generic border-pixel routine with height = 2 and width = 2. | |||
| void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, | |||
| int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier, int left_shift, | |||
| int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| const int window_h = 2; | |||
| const int window_w = 2; | |||
| ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_h, window_w, in_kh_step, in_kw_step, channel, in_zp, out_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| } | |||
| // Fallback used when ENABLE_ARM64 is not defined: handles a border pixel covered by a 2-row x 3-column kernel | |||
| // window by forwarding to the generic border-pixel routine with height = 2 and width = 3. | |||
| void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, | |||
| int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier, | |||
| int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| const int window_h = 2; | |||
| const int window_w = 3; | |||
| ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_h, window_w, in_kh_step, in_kw_step, channel, in_zp, out_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| } | |||
| // Fallback used when ENABLE_ARM64 is not defined: handles a border pixel covered by a 3-row x 2-column kernel | |||
| // window by forwarding to the generic border-pixel routine with height = 3 and width = 2. | |||
| void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, | |||
| int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier, | |||
| int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| const int window_h = 3; | |||
| const int window_w = 2; | |||
| ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_h, window_w, in_kh_step, in_kw_step, channel, in_zp, out_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| } | |||
| #endif | |||
| void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, | |||
| void ConvDw3x3Int8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, | |||
| int bottom, int left, int right, const ConvParameter *conv_param, | |||
| const SlidingWindowParam *sliding, int8_t in_zp, int32_t out_zp, int out_multiplier, | |||
| int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| @@ -361,7 +383,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, | |||
| const int16_t *weight_kernel = | |||
| weight + (start_kh * conv_param->kernel_w_ + start_kw) * conv_param->input_channel_; | |||
| ConvDw3x3BorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, | |||
| ConvDw3x3Int8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, | |||
| sliding->in_kh_step_, sliding->in_kw_step_, conv_param->input_channel_, in_zp, out_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| @@ -371,7 +393,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, | |||
| } // height loop | |||
| } | |||
| void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, | |||
| void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, | |||
| const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { | |||
| int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; | |||
| int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; | |||
| @@ -380,17 +402,70 @@ void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16 | |||
| int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | |||
| int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; | |||
| int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; | |||
| ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, 0, sliding->top_, 0, conv_param->output_w_, | |||
| conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->bottom_, conv_param->output_h_, 0, | |||
| conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, | |||
| right_shift, acc_min, acc_max); | |||
| ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, 0, | |||
| sliding->left_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, | |||
| acc_min, acc_max); | |||
| ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, sliding->right_, | |||
| conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, | |||
| right_shift, acc_min, acc_max); | |||
| int input_row_size = conv_param->input_w_ * conv_param->input_channel_; | |||
| int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; | |||
| int output_row_size = conv_param->output_w_ * conv_param->output_channel_; | |||
| int in_kh_step = sliding->in_kh_step_; | |||
| int in_kw_step = sliding->in_kw_step_; | |||
| // top | |||
| const int8_t *input = input_data; | |||
| const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; | |||
| int8_t *output = output_data; | |||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; | |||
| weight = weight_data + weight_row_size; | |||
| output += conv_param->output_channel_; | |||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||
| output += conv_param->output_channel_; | |||
| } | |||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| // left | |||
| input = input_data + (conv_param->stride_h_ - 1) * input_row_size; | |||
| weight = weight_data + conv_param->input_channel_; | |||
| output = output_data + output_row_size; | |||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += conv_param->stride_h_ * input_row_size; | |||
| output += output_row_size; | |||
| } | |||
| // right | |||
| input = | |||
| input_data + (conv_param->input_w_ - 2) * conv_param->input_channel_ + (conv_param->stride_h_ - 1) * input_row_size; | |||
| weight = weight_data; | |||
| output = output_data + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; | |||
| for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { | |||
| ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += conv_param->stride_h_ * input_row_size; | |||
| output += output_row_size; | |||
| } | |||
| // bottom | |||
| input = input_data + (conv_param->input_h_ - 2) * input_row_size; | |||
| weight = weight_data + conv_param->input_channel_; | |||
| output = output_data + (conv_param->output_h_ - 1) * output_row_size; | |||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; | |||
| weight = weight_data; | |||
| output += conv_param->output_channel_; | |||
| for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { | |||
| ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| input += conv_param->stride_w_ * conv_param->input_channel_; | |||
| output += conv_param->output_channel_; | |||
| } | |||
| ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, | |||
| out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| } | |||
| /*conv depthwise 3x3 int8 end*/ | |||
| @@ -29,7 +29,7 @@ bool CheckIfUse3X3(const ConvParameter *conv_param, int channel); | |||
| void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, | |||
| const int32_t *bias_data, const ConvParameter *conv_param, int task_id); | |||
| void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, | |||
| void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, | |||
| const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding); | |||
| void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, | |||
| @@ -44,13 +44,6 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in | |||
| const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, | |||
| int task_id); | |||
| #ifdef ENABLE_ARM64 | |||
| void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, | |||
| int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, | |||
| int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min, | |||
| int32_t acc_max); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -164,7 +164,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::Run() { | |||
| if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 || | |||
| sliding_->right_ < conv_param_->output_w_) { | |||
| ConvDw3x3PadInt8(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_, | |||
| ConvDw3x3Int8Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_, | |||
| sliding_); | |||
| } | |||
| ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Int8Run, this, conv_param_->thread_num_); | |||