Merge pull request !6038 from yangruoqi713/act_per_channel
@@ -7,13 +7,15 @@
.type ConvDwInt8Center, %function
#endif
// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width,
//                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
//                       size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift,
//                       int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
//                       size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
//                       size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
//                       int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
//                       int32_t *acc_min, int32_t *acc_max)
// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w,
// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step
// x14: out_multiplier, #56: left_shift, #64: right_shift, #72: out_zp, #80: acc_min, #88: acc_max
// x14: in_zp, #56: out_zp, #64: out_multiplier, #72: left_shift, #80: right_shift, #88: acc_min, #96: acc_max
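For orientation, the per-channel requantization that each group of four accumulators goes through below (sqshl, sqrdmulh, then the and/sshr/sqadd/srshl rounding idiom) corresponds to the following self-contained C model. The helper names here are hypothetical; the rounding follows the usual gemmlowp scheme, which the scalar fallback later in this change also uses, and right_shift is assumed to be stored as a non-positive value:

    #include <stdint.h>

    // Saturating rounding doubling high multiply: (a * b * 2) >> 31, rounded to nearest.
    static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
      if (a == INT32_MIN && b == INT32_MIN) {
        return INT32_MAX;  // the only overflowing input pair
      }
      int64_t ab = (int64_t)a * (int64_t)b;
      int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
      return (int32_t)((ab + nudge) / (1ll << 31));
    }

    // Rounding right shift, half away from zero: the and/sshr/sqadd/srshl sequence.
    static int32_t RoundingRightShift(int32_t x, int exponent) {
      int32_t mask = (int32_t)((1ll << exponent) - 1);
      int32_t remainder = x & mask;
      int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

    // One output channel c; the kernel below does this four lanes at a time.
    // Saturation on the left shift is omitted for brevity (the asm uses sqshl).
    static int8_t RequantizeChannel(int32_t acc, int c, const int32_t *out_multiplier,
                                    const int32_t *left_shift, const int32_t *right_shift,
                                    const int32_t *out_zp, const int32_t *acc_min,
                                    const int32_t *acc_max) {
      acc = SatRoundDoublingHighMul(acc * (1 << left_shift[c]), out_multiplier[c]);
      acc = RoundingRightShift(acc, -right_shift[c]);  // right_shift[c] <= 0 assumed
      acc += out_zp[c];
      acc = acc < acc_min[c] ? acc_min[c] : acc;
      acc = acc > acc_max[c] ? acc_max[c] : acc;
      return (int8_t)acc;
    }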
ConvDwInt8Center:
    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
@@ -33,489 +35,174 @@ ConvDwInt8Center:
    ldr x12, [sp, #32]
    ldr x13, [sp, #40]
    ldr w14, [sp, #56]
    dup v26.4s, w14
    ldr x14, [sp, #48] // input_zp
    ld1 {v19.8b}, [x14], #8
    ldr x15, [sp, #56] // output_zp
    ld1 {v20.4s}, [x15], #16
    ld1 {v21.4s}, [x15], #16
    ldr x15, [sp, #48]
    dup v27.4s, w15
    ldr x16, [sp, #64] // out_multiplier
    ld1 {v22.4s}, [x16], #16
    ld1 {v23.4s}, [x16], #16
    ldr w16, [sp, #64]
    dup v28.4s, w16
    ldr x17, [sp, #72] // left_shift
    ld1 {v24.4s}, [x17], #16
    ld1 {v25.4s}, [x17], #16
    ldr w17, [sp, #72]
    dup v29.4s, w17
    ldr w18, [sp, #80]
    dup v30.4s, w18
    ldr x18, [sp, #80] // right shift
    ld1 {v26.4s}, [x18], #16
    ld1 {v27.4s}, [x18], #16
    ldr w19, [sp, #88]
    dup v31.4s, w19
    ldr x19, [sp, #88] // acc_min
    ld1 {v28.4s}, [x19], #16
    ld1 {v29.4s}, [x19], #16
    ld1 {v24.4s}, [x3]
    ldr x20, [sp, #96] // acc_max
    ld1 {v30.4s}, [x20], #16
    ld1 {v31.4s}, [x20], #16
    ld1 {v17.4s}, [x3], #16
    ld1 {v18.4s}, [x3], #16
LoopH:
    mov x23, x1
    mov x24, x5
    mov x3, x0
    cmp x24, #8
    blt LoopW
    cmp x24, #16
    blt LoopW8
LoopW16:
    mov x19, #16
LoopW4:
    mov x19, #4
    mul x19, x19, x11
    mov x25, #4
    mul x25, x25, x9
    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v1.16b, v24.16b
    mov v2.16b, v24.16b
    mov v3.16b, v24.16b
    mov v4.16b, v24.16b
    mov v5.16b, v24.16b
    mov v6.16b, v24.16b
    mov v7.16b, v24.16b
    mov v8.16b, v24.16b
    mov v9.16b, v24.16b
    mov v10.16b, v24.16b
    mov v11.16b, v24.16b
    mov v12.16b, v24.16b
    mov v13.16b, v24.16b
    mov v14.16b, v24.16b
    mov v15.16b, v24.16b
LoopKh16:
    mov v0.16b, v17.16b
    mov v1.16b, v18.16b
    mov v2.16b, v17.16b
    mov v3.16b, v18.16b
    mov v4.16b, v17.16b
    mov v5.16b, v18.16b
    mov v6.16b, v17.16b
    mov v7.16b, v18.16b
LoopKh4:
    mov x18, x7
    mov x21, x16
LoopKw16:
LoopKw4:
    mov x22, x21
    ld1 {v25.4h}, [x17], #8
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v0.4s, v16.4h, v25.4h
    smlal v1.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v2.4s, v18.4h, v25.4h
    smlal v3.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v4.4s, v20.4h, v25.4h
    smlal v5.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v6.4s, v22.4h, v25.4h
    smlal v7.4s, v23.4h, v25.4h
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v8.4s, v16.4h, v25.4h
    smlal v9.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v10.4s, v18.4h, v25.4h
    smlal v11.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v12.4s, v20.4h, v25.4h
    smlal v13.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v14.4s, v22.4h, v25.4h
    smlal v15.4s, v23.4h, v25.4h
    subs x18, x18, #1
    add x21, x21, x13
    bne LoopKw16
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh16
    sqshl v0.4s, v0.4s, v26.4s
    sqshl v1.4s, v1.4s, v26.4s
    sqshl v2.4s, v2.4s, v26.4s
    sqshl v3.4s, v3.4s, v26.4s
    sqshl v4.4s, v4.4s, v26.4s
    sqshl v5.4s, v5.4s, v26.4s
    sqshl v6.4s, v6.4s, v26.4s
    sqshl v7.4s, v7.4s, v26.4s
    sqshl v8.4s, v8.4s, v26.4s
    sqshl v9.4s, v9.4s, v26.4s
    sqshl v10.4s, v10.4s, v26.4s
    sqshl v11.4s, v11.4s, v26.4s
    sqshl v12.4s, v12.4s, v26.4s
    sqshl v13.4s, v13.4s, v26.4s
    sqshl v14.4s, v14.4s, v26.4s
    sqshl v15.4s, v15.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqrdmulh v1.4s, v1.4s, v27.4s
    sqrdmulh v2.4s, v2.4s, v27.4s
    sqrdmulh v3.4s, v3.4s, v27.4s
    sqrdmulh v4.4s, v4.4s, v27.4s
    sqrdmulh v5.4s, v5.4s, v27.4s
    sqrdmulh v6.4s, v6.4s, v27.4s
    sqrdmulh v7.4s, v7.4s, v27.4s
    sqrdmulh v8.4s, v8.4s, v27.4s
    sqrdmulh v9.4s, v9.4s, v27.4s
    sqrdmulh v10.4s, v10.4s, v27.4s
    sqrdmulh v11.4s, v11.4s, v27.4s
    sqrdmulh v12.4s, v12.4s, v27.4s
    sqrdmulh v13.4s, v13.4s, v27.4s
    sqrdmulh v14.4s, v14.4s, v27.4s
    sqrdmulh v15.4s, v15.4s, v27.4s
    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v17.16b, v28.16b, v1.16b
    sshr v17.4s, v17.4s, #31
    sqadd v1.4s, v1.4s, v17.4s
    srshl v1.4s, v1.4s, v28.4s
    and v18.16b, v28.16b, v2.16b
    sshr v18.4s, v18.4s, #31
    sqadd v2.4s, v2.4s, v18.4s
    srshl v2.4s, v2.4s, v28.4s
    and v19.16b, v28.16b, v3.16b
    sshr v19.4s, v19.4s, #31
    sqadd v3.4s, v3.4s, v19.4s
    srshl v3.4s, v3.4s, v28.4s
    and v20.16b, v28.16b, v4.16b
    sshr v20.4s, v20.4s, #31
    sqadd v4.4s, v4.4s, v20.4s
    srshl v4.4s, v4.4s, v28.4s
    and v21.16b, v28.16b, v5.16b
    sshr v21.4s, v21.4s, #31
    sqadd v5.4s, v5.4s, v21.4s
    srshl v5.4s, v5.4s, v28.4s
    and v22.16b, v28.16b, v6.16b
    sshr v22.4s, v22.4s, #31
    sqadd v6.4s, v6.4s, v22.4s
    srshl v6.4s, v6.4s, v28.4s
    and v23.16b, v28.16b, v7.16b
    sshr v23.4s, v23.4s, #31
    sqadd v7.4s, v7.4s, v23.4s
    srshl v7.4s, v7.4s, v28.4s
    and v16.16b, v28.16b, v8.16b
    sshr v16.4s, v16.4s, #31
    sqadd v8.4s, v8.4s, v16.4s
    srshl v8.4s, v8.4s, v28.4s
    and v17.16b, v28.16b, v9.16b
    sshr v17.4s, v17.4s, #31
    sqadd v9.4s, v9.4s, v17.4s
    srshl v9.4s, v9.4s, v28.4s
    and v18.16b, v28.16b, v10.16b
    sshr v18.4s, v18.4s, #31
    sqadd v10.4s, v10.4s, v18.4s
    srshl v10.4s, v10.4s, v28.4s
    and v19.16b, v28.16b, v11.16b
    sshr v19.4s, v19.4s, #31
    sqadd v11.4s, v11.4s, v19.4s
    srshl v11.4s, v11.4s, v28.4s
    and v20.16b, v28.16b, v12.16b
    sshr v20.4s, v20.4s, #31
    sqadd v12.4s, v12.4s, v20.4s
    srshl v12.4s, v12.4s, v28.4s
    and v21.16b, v28.16b, v13.16b
    sshr v21.4s, v21.4s, #31
    sqadd v13.4s, v13.4s, v21.4s
    srshl v13.4s, v13.4s, v28.4s
    and v22.16b, v28.16b, v14.16b
    sshr v22.4s, v22.4s, #31
    sqadd v14.4s, v14.4s, v22.4s
    srshl v14.4s, v14.4s, v28.4s
    and v23.16b, v28.16b, v15.16b
    sshr v23.4s, v23.4s, #31
    sqadd v15.4s, v15.4s, v23.4s
    srshl v15.4s, v15.4s, v28.4s
    add v0.4s, v0.4s, v29.4s
    add v1.4s, v1.4s, v29.4s
    add v2.4s, v2.4s, v29.4s
    add v3.4s, v3.4s, v29.4s
    add v4.4s, v4.4s, v29.4s
    add v5.4s, v5.4s, v29.4s
    add v6.4s, v6.4s, v29.4s
    add v7.4s, v7.4s, v29.4s
    add v8.4s, v8.4s, v29.4s
    add v9.4s, v9.4s, v29.4s
    add v10.4s, v10.4s, v29.4s
    add v11.4s, v11.4s, v29.4s
    add v12.4s, v12.4s, v29.4s
    add v13.4s, v13.4s, v29.4s
    add v14.4s, v14.4s, v29.4s
    add v15.4s, v15.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smax v1.4s, v1.4s, v30.4s
    smax v2.4s, v2.4s, v30.4s
    smax v3.4s, v3.4s, v30.4s
    smax v4.4s, v4.4s, v30.4s
    smax v5.4s, v5.4s, v30.4s
    smax v6.4s, v6.4s, v30.4s
    smax v7.4s, v7.4s, v30.4s
    smax v8.4s, v8.4s, v30.4s
    smax v9.4s, v9.4s, v30.4s
    smax v10.4s, v10.4s, v30.4s
    smax v11.4s, v11.4s, v30.4s
    smax v12.4s, v12.4s, v30.4s
    smax v13.4s, v13.4s, v30.4s
    smax v14.4s, v14.4s, v30.4s
    smax v15.4s, v15.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    smin v1.4s, v1.4s, v31.4s
    smin v2.4s, v2.4s, v31.4s
    smin v3.4s, v3.4s, v31.4s
    smin v4.4s, v4.4s, v31.4s
    smin v5.4s, v5.4s, v31.4s
    smin v6.4s, v6.4s, v31.4s
    smin v7.4s, v7.4s, v31.4s
    smin v8.4s, v8.4s, v31.4s
    smin v9.4s, v9.4s, v31.4s
    smin v10.4s, v10.4s, v31.4s
    smin v11.4s, v11.4s, v31.4s
    smin v12.4s, v12.4s, v31.4s
    smin v13.4s, v13.4s, v31.4s
    smin v14.4s, v14.4s, v31.4s
    smin v15.4s, v15.4s, v31.4s
    ld1 {v16.8h}, [x17], #16
    sqxtn v0.4h, v0.4s
    sqxtn v1.4h, v1.4s
    sqxtn v2.4h, v2.4s
    sqxtn v3.4h, v3.4s
    sqxtn v4.4h, v4.4s
    sqxtn v5.4h, v5.4s
    sqxtn v6.4h, v6.4s
    sqxtn v7.4h, v7.4s
    sqxtn v8.4h, v8.4s
    sqxtn v9.4h, v9.4s
    sqxtn v10.4h, v10.4s
    sqxtn v11.4h, v11.4s
    sqxtn v12.4h, v12.4s
    sqxtn v13.4h, v13.4s
    sqxtn v14.4h, v14.4s
    sqxtn v15.4h, v15.4s
    sqxtn v0.8b, v0.8h
    sqxtn v1.8b, v1.8h
    sqxtn v2.8b, v2.8h
    sqxtn v3.8b, v3.8h
    sqxtn v4.8b, v4.8h
    sqxtn v5.8b, v5.8h
    sqxtn v6.8b, v6.8h
    sqxtn v7.8b, v7.8h
    sqxtn v8.8b, v8.8h
    sqxtn v9.8b, v9.8h
    sqxtn v10.8b, v10.8h
    sqxtn v11.8b, v11.8h
    sqxtn v12.8b, v12.8h
    sqxtn v13.8b, v13.8h
    sqxtn v14.8b, v14.8h
    sqxtn v15.8b, v15.8h
    add x17, x3, #1
    add x18, x3, #2
    add x21, x3, #3
    st1 {v0.b}[0], [x3], x9
    st1 {v0.b}[1], [x17], x9
    st1 {v0.b}[2], [x18], x9
    st1 {v0.b}[3], [x21], x9
    st1 {v1.b}[0], [x3], x9
    st1 {v1.b}[1], [x17], x9
    st1 {v1.b}[2], [x18], x9
    st1 {v1.b}[3], [x21], x9
    st1 {v2.b}[0], [x3], x9
    st1 {v2.b}[1], [x17], x9
    st1 {v2.b}[2], [x18], x9
    st1 {v2.b}[3], [x21], x9
    st1 {v3.b}[0], [x3], x9
    st1 {v3.b}[1], [x17], x9
    st1 {v3.b}[2], [x18], x9
    st1 {v3.b}[3], [x21], x9
    st1 {v4.b}[0], [x3], x9
    st1 {v4.b}[1], [x17], x9
    st1 {v4.b}[2], [x18], x9
    st1 {v4.b}[3], [x21], x9
    st1 {v5.b}[0], [x3], x9
    st1 {v5.b}[1], [x17], x9
    st1 {v5.b}[2], [x18], x9
    st1 {v5.b}[3], [x21], x9
    st1 {v6.b}[0], [x3], x9
    st1 {v6.b}[1], [x17], x9
    st1 {v6.b}[2], [x18], x9
    st1 {v6.b}[3], [x21], x9
    st1 {v7.b}[0], [x3], x9
    st1 {v7.b}[1], [x17], x9
    st1 {v7.b}[2], [x18], x9
    st1 {v7.b}[3], [x21], x9
    st1 {v8.b}[0], [x3], x9
    st1 {v8.b}[1], [x17], x9
    st1 {v8.b}[2], [x18], x9
    st1 {v8.b}[3], [x21], x9
    st1 {v9.b}[0], [x3], x9
    st1 {v9.b}[1], [x17], x9
    st1 {v9.b}[2], [x18], x9
    st1 {v9.b}[3], [x21], x9
    st1 {v10.b}[0], [x3], x9
    st1 {v10.b}[1], [x17], x9
    st1 {v10.b}[2], [x18], x9
    st1 {v10.b}[3], [x21], x9
    st1 {v11.b}[0], [x3], x9
    st1 {v11.b}[1], [x17], x9
    st1 {v11.b}[2], [x18], x9
    st1 {v11.b}[3], [x21], x9
    st1 {v12.b}[0], [x3], x9
    st1 {v12.b}[1], [x17], x9
    st1 {v12.b}[2], [x18], x9
    st1 {v12.b}[3], [x21], x9
    st1 {v13.b}[0], [x3], x9
    st1 {v13.b}[1], [x17], x9
    st1 {v13.b}[2], [x18], x9
    st1 {v13.b}[3], [x21], x9
    st1 {v14.b}[0], [x3], x9
    st1 {v14.b}[1], [x17], x9
    st1 {v14.b}[2], [x18], x9
    st1 {v14.b}[3], [x21], x9
    st1 {v15.b}[0], [x3], x9
    st1 {v15.b}[1], [x17], x9
    st1 {v15.b}[2], [x18], x9
    st1 {v15.b}[3], [x21], x9
    ld1 {v15.8b}, [x22], x11
    ssubl v14.8h, v15.8b, v19.8b
    smlal v0.4s, v14.4h, v16.4h
    smlal2 v1.4s, v14.8h, v16.8h
    ld1 {v13.8b}, [x22], x11
    ssubl v12.8h, v13.8b, v19.8b
    smlal v2.4s, v12.4h, v16.4h
    smlal2 v3.4s, v12.8h, v16.8h
    ld1 {v11.8b}, [x22], x11
    ssubl v10.8h, v11.8b, v19.8b
    smlal v4.4s, v10.4h, v16.4h
    smlal2 v5.4s, v10.8h, v16.8h
    ld1 {v9.8b}, [x22], x11
    ssubl v8.8h, v9.8b, v19.8b
    smlal v6.4s, v8.4h, v16.4h
    smlal2 v7.4s, v8.8h, v16.8h
    add x23, x23, x19
    sub x24, x24, #16
    cmp x24, #0
    ble LoopWEnd
    cmp x24, #8
    blt LoopW
    cmp x24, #16
    bge LoopW16
LoopW8:
    mov x19, #8
    mul x19, x19, x11
    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v1.16b, v24.16b
    mov v2.16b, v24.16b
    mov v3.16b, v24.16b
    mov v4.16b, v24.16b
    mov v5.16b, v24.16b
    mov v6.16b, v24.16b
    mov v7.16b, v24.16b
LoopKh8:
    mov x18, x7
    mov x21, x16
LoopKw8:
    mov x22, x21
    ld1 {v25.4h}, [x17], #8
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v0.4s, v16.4h, v25.4h
    smlal v1.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v2.4s, v18.4h, v25.4h
    smlal v3.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v4.4s, v20.4h, v25.4h
    smlal v5.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v6.4s, v22.4h, v25.4h
    smlal v7.4s, v23.4h, v25.4h
    subs x18, x18, #1
    add x21, x21, x13
    bne LoopKw8
    bne LoopKw4
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh8
    sqshl v0.4s, v0.4s, v26.4s
    sqshl v1.4s, v1.4s, v26.4s
    sqshl v2.4s, v2.4s, v26.4s
    sqshl v3.4s, v3.4s, v26.4s
    sqshl v4.4s, v4.4s, v26.4s
    sqshl v5.4s, v5.4s, v26.4s
    sqshl v6.4s, v6.4s, v26.4s
    sqshl v7.4s, v7.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqrdmulh v1.4s, v1.4s, v27.4s
    sqrdmulh v2.4s, v2.4s, v27.4s
    sqrdmulh v3.4s, v3.4s, v27.4s
    sqrdmulh v4.4s, v4.4s, v27.4s
    sqrdmulh v5.4s, v5.4s, v27.4s
    sqrdmulh v6.4s, v6.4s, v27.4s
    sqrdmulh v7.4s, v7.4s, v27.4s
    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v17.16b, v28.16b, v1.16b
    sshr v17.4s, v17.4s, #31
    sqadd v1.4s, v1.4s, v17.4s
    srshl v1.4s, v1.4s, v28.4s
    and v18.16b, v28.16b, v2.16b
    sshr v18.4s, v18.4s, #31
    sqadd v2.4s, v2.4s, v18.4s
    srshl v2.4s, v2.4s, v28.4s
    and v19.16b, v28.16b, v3.16b
    sshr v19.4s, v19.4s, #31
    sqadd v3.4s, v3.4s, v19.4s
    srshl v3.4s, v3.4s, v28.4s
    and v20.16b, v28.16b, v4.16b
    sshr v20.4s, v20.4s, #31
    sqadd v4.4s, v4.4s, v20.4s
    srshl v4.4s, v4.4s, v28.4s
    and v21.16b, v28.16b, v5.16b
    sshr v21.4s, v21.4s, #31
    sqadd v5.4s, v5.4s, v21.4s
    srshl v5.4s, v5.4s, v28.4s
    and v22.16b, v28.16b, v6.16b
    sshr v22.4s, v22.4s, #31
    sqadd v6.4s, v6.4s, v22.4s
    srshl v6.4s, v6.4s, v28.4s
    and v23.16b, v28.16b, v7.16b
    sshr v23.4s, v23.4s, #31
    sqadd v7.4s, v7.4s, v23.4s
    srshl v7.4s, v7.4s, v28.4s
    add v0.4s, v0.4s, v29.4s
    add v1.4s, v1.4s, v29.4s
    add v2.4s, v2.4s, v29.4s
    add v3.4s, v3.4s, v29.4s
    add v4.4s, v4.4s, v29.4s
    add v5.4s, v5.4s, v29.4s
    add v6.4s, v6.4s, v29.4s
    add v7.4s, v7.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smax v1.4s, v1.4s, v30.4s
    smax v2.4s, v2.4s, v30.4s
    smax v3.4s, v3.4s, v30.4s
    smax v4.4s, v4.4s, v30.4s
    smax v5.4s, v5.4s, v30.4s
    smax v6.4s, v6.4s, v30.4s
    smax v7.4s, v7.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    bne LoopKh4
    sqshl v0.4s, v0.4s, v24.4s
    sqshl v1.4s, v1.4s, v25.4s
    sqshl v2.4s, v2.4s, v24.4s
    sqshl v3.4s, v3.4s, v25.4s
    sqshl v4.4s, v4.4s, v24.4s
    sqshl v5.4s, v5.4s, v25.4s
    sqshl v6.4s, v6.4s, v24.4s
    sqshl v7.4s, v7.4s, v25.4s
    sqrdmulh v0.4s, v0.4s, v22.4s
    sqrdmulh v1.4s, v1.4s, v23.4s
    sqrdmulh v2.4s, v2.4s, v22.4s
    sqrdmulh v3.4s, v3.4s, v23.4s
    sqrdmulh v4.4s, v4.4s, v22.4s
    sqrdmulh v5.4s, v5.4s, v23.4s
    sqrdmulh v6.4s, v6.4s, v22.4s
    sqrdmulh v7.4s, v7.4s, v23.4s
    and v15.16b, v26.16b, v0.16b
    sshr v15.4s, v15.4s, #31
    sqadd v0.4s, v0.4s, v15.4s
    srshl v0.4s, v0.4s, v26.4s
    and v14.16b, v27.16b, v1.16b
    sshr v14.4s, v14.4s, #31
    sqadd v1.4s, v1.4s, v14.4s
    srshl v1.4s, v1.4s, v27.4s
    and v13.16b, v26.16b, v2.16b
    sshr v13.4s, v13.4s, #31
    sqadd v2.4s, v2.4s, v13.4s
    srshl v2.4s, v2.4s, v26.4s
    and v12.16b, v27.16b, v3.16b
    sshr v12.4s, v12.4s, #31
    sqadd v3.4s, v3.4s, v12.4s
    srshl v3.4s, v3.4s, v27.4s
    and v11.16b, v26.16b, v4.16b
    sshr v11.4s, v11.4s, #31
    sqadd v4.4s, v4.4s, v11.4s
    srshl v4.4s, v4.4s, v26.4s
    and v10.16b, v27.16b, v5.16b
    sshr v10.4s, v10.4s, #31
    sqadd v5.4s, v5.4s, v10.4s
    srshl v5.4s, v5.4s, v27.4s
    and v9.16b, v26.16b, v6.16b
    sshr v9.4s, v9.4s, #31
    sqadd v6.4s, v6.4s, v9.4s
    srshl v6.4s, v6.4s, v26.4s
    and v8.16b, v27.16b, v7.16b
    sshr v8.4s, v8.4s, #31
    sqadd v7.4s, v7.4s, v8.4s
    srshl v7.4s, v7.4s, v27.4s
    add v0.4s, v0.4s, v20.4s
    add v1.4s, v1.4s, v21.4s
    add v2.4s, v2.4s, v20.4s
    add v3.4s, v3.4s, v21.4s
    add v4.4s, v4.4s, v20.4s
    add v5.4s, v5.4s, v21.4s
    add v6.4s, v6.4s, v20.4s
    add v7.4s, v7.4s, v21.4s
    smax v0.4s, v0.4s, v28.4s
    smax v1.4s, v1.4s, v29.4s
    smax v2.4s, v2.4s, v28.4s
    smax v3.4s, v3.4s, v29.4s
    smax v4.4s, v4.4s, v28.4s
    smax v5.4s, v5.4s, v29.4s
    smax v6.4s, v6.4s, v28.4s
    smax v7.4s, v7.4s, v29.4s
    smin v0.4s, v0.4s, v30.4s
    smin v1.4s, v1.4s, v31.4s
    smin v2.4s, v2.4s, v31.4s
    smin v2.4s, v2.4s, v30.4s
    smin v3.4s, v3.4s, v31.4s
    smin v4.4s, v4.4s, v31.4s
    smin v4.4s, v4.4s, v30.4s
    smin v5.4s, v5.4s, v31.4s
    smin v6.4s, v6.4s, v31.4s
    smin v6.4s, v6.4s, v30.4s
    smin v7.4s, v7.4s, v31.4s
    sqxtn v0.4h, v0.4s
@@ -535,93 +222,81 @@ ConvDwInt8Center:
    sqxtn v6.8b, v6.8h
    sqxtn v7.8b, v7.8h
    add x17, x3, #1
    add x18, x3, #2
    add x21, x3, #3
    st1 {v0.b}[0], [x3], x9
    st1 {v0.b}[1], [x17], x9
    st1 {v0.b}[2], [x18], x9
    st1 {v0.b}[3], [x21], x9
    st1 {v1.b}[0], [x3], x9
    st1 {v1.b}[1], [x17], x9
    st1 {v1.b}[2], [x18], x9
    st1 {v1.b}[3], [x21], x9
    st1 {v2.b}[0], [x3], x9
    st1 {v2.b}[1], [x17], x9
    st1 {v2.b}[2], [x18], x9
    st1 {v2.b}[3], [x21], x9
    st1 {v3.b}[0], [x3], x9
    st1 {v3.b}[1], [x17], x9
    st1 {v3.b}[2], [x18], x9
    st1 {v3.b}[3], [x21], x9
    st1 {v4.b}[0], [x3], x9
    st1 {v4.b}[1], [x17], x9
    st1 {v4.b}[2], [x18], x9
    st1 {v4.b}[3], [x21], x9
    st1 {v5.b}[0], [x3], x9
    st1 {v5.b}[1], [x17], x9
    st1 {v5.b}[2], [x18], x9
    st1 {v5.b}[3], [x21], x9
    st1 {v6.b}[0], [x3], x9
    st1 {v6.b}[1], [x17], x9
    st1 {v6.b}[2], [x18], x9
    st1 {v6.b}[3], [x21], x9
    st1 {v7.b}[0], [x3], x9
    st1 {v7.b}[1], [x17], x9
    st1 {v7.b}[2], [x18], x9
    st1 {v7.b}[3], [x21], x9
    mov x16, x3
    add x17, x16, x9
    add x18, x17, x9
    add x21, x18, x9
    st1 {v0.s}[0], [x16], #4
    st1 {v1.s}[0], [x16], #4
    st1 {v2.s}[0], [x17], #4
    st1 {v3.s}[0], [x17], #4
    st1 {v4.s}[0], [x18], #4
    st1 {v5.s}[0], [x18], #4
    st1 {v6.s}[0], [x21], #4
    st1 {v7.s}[0], [x21], #4
    add x3, x3, x25
    add x23, x23, x19
    sub x24, x24, #8
    sub x24, x24, #4
    cmp x24, #0
    ble LoopWEnd
    cmp x24, #8
    bge LoopW8
    cmp x24, #4
    bge LoopW4
LoopW:
    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v0.16b, v17.16b
    mov v1.16b, v18.16b
LoopKh:
    mov x18, x7
    mov x22, x16
LoopKw:
    ld1 {v16.4h}, [x22], x13
    ld1 {v25.4h}, [x17], #8
    smlal v0.4s, v16.4h, v25.4h
    ld1 {v15.8b}, [x22], x13
    ssubl v14.8h, v15.8b, v19.8b
    ld1 {v16.8h}, [x17], #16
    smlal v0.4s, v14.4h, v16.4h
    smlal2 v1.4s, v14.8h, v16.8h
    subs x18, x18, #1
    bne LoopKw
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh
    sqshl v0.4s, v0.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqshl v0.4s, v0.4s, v24.4s
    sqrdmulh v0.4s, v0.4s, v22.4s
    sqshl v1.4s, v1.4s, v25.4s
    sqrdmulh v1.4s, v1.4s, v23.4s
    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v15.16b, v26.16b, v0.16b
    sshr v15.4s, v15.4s, #31
    sqadd v0.4s, v0.4s, v15.4s
    srshl v0.4s, v0.4s, v26.4s
    add v0.4s, v0.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    and v14.16b, v27.16b, v1.16b
    sshr v14.4s, v14.4s, #31
    sqadd v1.4s, v1.4s, v14.4s
    srshl v1.4s, v1.4s, v27.4s
    add v0.4s, v0.4s, v20.4s
    smax v0.4s, v0.4s, v28.4s
    smin v0.4s, v0.4s, v30.4s
    sqxtn v0.4h, v0.4s
    sqxtn v0.8b, v0.8h
    add v1.4s, v1.4s, v21.4s
    smax v1.4s, v1.4s, v29.4s
    smin v1.4s, v1.4s, v31.4s
    sqxtn v1.4h, v1.4s
    sqxtn v1.8b, v1.8h
    mov x17, x3
    st1 {v0.b}[0], [x17], #1
    st1 {v0.b}[1], [x17], #1
    st1 {v0.b}[2], [x17], #1
    st1 {v0.b}[3], [x17], #1
    st1 {v0.s}[0], [x17], #4
    st1 {v1.s}[0], [x17], #4
    add x3, x3, x9
    add x23, x23, x11
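The rewritten inner loops above replace the old pre-widened int16 input path: each kernel tap now loads eight raw int8 inputs, widens them while subtracting the per-channel input zero point in a single ssubl, and multiply-accumulates against the packed int16 weights with smlal/smlal2. A scalar sketch of one tap over an 8-channel block, mirroring the C fallback later in this change:

    #include <stdint.h>

    // acc[c] += (src[c] - in_zp[c]) * weight[c] for one kernel tap, C8NUM == 8.
    static void MacOneTap(int32_t acc[8], const int8_t src[8], const int8_t in_zp[8],
                          const int16_t weight[8]) {
      for (int c = 0; c < 8; ++c) {
        int16_t x = (int16_t)src[c] - (int16_t)in_zp[c];  // ssubl: widen and subtract zp
        acc[c] += (int32_t)x * (int32_t)weight[c];        // smlal/smlal2: widening MAC
      }
    }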
@@ -45,10 +45,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                        size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                        size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height,
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
                      size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
                      int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
                      int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
                      int32_t *acc_min, int32_t *acc_max);
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
                   int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
@@ -138,75 +138,67 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
}
/*conv depthwise int8 end*/
/*conv depthwise sliding window int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
                              int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
                              int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
                              bool per_channel) {
  int tmp_buffer[C4NUM];
  for (int i = 0; i < C4NUM; i++) {
/*conv depthwise sliding window perchannel int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                              int width, int in_kh_step, int in_kw_step, int kernel_w, int8_t *input_zp,
                              int32_t *out_zp, int *out_multiplier, int *left_shift, int *right_shift, int32_t *acc_min,
                              int32_t *acc_max) {
  int tmp_buffer[C8NUM];
  for (int i = 0; i < C8NUM; i++) {
    tmp_buffer[i] = 0;
  }
  const int16_t *src_kh = src;
  const int8_t *src_kh = src;
  const int16_t *weight_kh = weight;
  for (int kh = 0; kh < height; kh++) {
    const int16_t *src_kw = src_kh;
    const int8_t *src_kw = src_kh;
    const int16_t *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        tmp_buffer[c] += src_kw[c] * weight_kw[c];
      for (int c = 0; c < C8NUM; c++) {
        tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
      weight_kw += C8NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w * C4NUM;
    weight_kh += kernel_w * C8NUM;
  }  // kernel_h loop
  int32_t left = left_shift[0];
  int32_t right = right_shift[0];
  int32_t multiplier = out_multiplier[0];
  for (int c = 0; c < C4NUM; c++) {
    if (per_channel) {
      left = left_shift[c];
      right = right_shift[c];
      multiplier = out_multiplier[c];
    }
  for (int c = 0; c < C8NUM; c++) {
    tmp_buffer[c] += bias[c];
    tmp_buffer[c] = RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
    tmp_buffer[c] += out_zp;
    tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
    tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
      SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
      -right_shift[c]);
    tmp_buffer[c] += out_zp[c];
    tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
    tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
    dst[c] = (tmp_buffer[c]);
  }
}
void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top,
void DepthwiseBorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
                         int bottom, int left, int right, const ConvParameter *conv_param,
                         const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift,
                         bool per_channel) {
                         const SlidingWindowParam *sliding, int8_t *in_zp, int32_t *out_zp, int *out_multiplier,
                         int *left_shift, int *right_shift, int32_t *acc_min, int32_t *acc_max) {
  int8_t *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const int16_t *src_h = src + ih * sliding->in_h_step_;
    const int8_t *src_h = src + ih * sliding->in_h_step_;
    int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const int16_t *src_w = src_h + iw * sliding->block_channel_;
      const int8_t *src_w = src_h + iw * sliding->block_channel_;
      const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
      const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM;
      DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                               sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier,
                               left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                               conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                               per_channel);
                               sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp,
                               out_multiplier, left_shift, right_shift, acc_min, acc_max);
      dst_kernel += sliding->block_channel_;
    }  // width loop
@@ -215,52 +207,46 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#ifndef ENABLE_ARM64
void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                         int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
                         int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift,
                         int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) {
  int tmp_buffer[C4NUM];
                         int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp,
                         int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min,
                         int32_t *acc_max) {
  int tmp_buffer[C8NUM];
  int8_t *dst_h = dst;
  const int16_t *src_h = src;
  const int8_t *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    int8_t *dst_w = dst_h;
    const int16_t *src_w = src_h;
    const int8_t *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const int16_t *src_kh = src_w;
      const int8_t *src_kh = src_w;
      const int16_t *weight_kh = weight;
      for (int i = 0; i < C4NUM; i++) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_buffer[i] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const int16_t *src_kw = src_kh;
        const int8_t *src_kw = src_kh;
        const int16_t *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            tmp_buffer[c] += src_kw[c] * weight_kw[c];
          for (int c = 0; c < C8NUM; c++) {
            tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
          weight_kw += C8NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
        weight_kh += kernel_w * C8NUM;
      }  // kernel_h loop
      // add bias relu
      int32_t left = left_shift[0];
      int32_t right = right_shift[0];
      int32_t multiplier = out_multiplier[0];
      for (int c = 0; c < C4NUM; c++) {
        if (per_channel) {
          left = left_shift[c];
          right = right_shift[c];
          multiplier = out_multiplier[c];
        }
      for (int c = 0; c < C8NUM; c++) {
        tmp_buffer[c] += bias[c];
        tmp_buffer[c] = RoundingDivideByPOT(
          SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
        tmp_buffer[c] += out_zp;
        tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
        tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
          SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
          -right_shift[c]);
        tmp_buffer[c] += out_zp[c];
        tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
        tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
        dst_w[c] = (tmp_buffer[c]);
      }
      dst_w += block_channel;
@@ -272,69 +258,65 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#endif
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const int16_t *src = input_data;
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
                  const SlidingWindowParam *sliding, int task_id) {
  const int8_t *src = input_data;
  int8_t *dst = output_data;
  bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
  int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int *right_shift = conv_param->conv_quant_arg_.right_shift_;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const int16_t *src_data = src + oc * C4NUM;
      int8_t *dst_data = dst + oc * C4NUM;
      const int8_t *src_data = src + oc * C8NUM;
      int8_t *dst_data = dst + oc * C8NUM;
      const int16_t *weight = weight_data + oc * sliding->kernel_step_;
      const int32_t *bias = bias_data + oc * C4NUM;
      const int32_t *bias = bias_data + oc * C8NUM;
      if (per_channel) {
        out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM;
        left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM;
        right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM;
      }
      int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM;
      int *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM;
      int *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM;
      int *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM;
      int *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM;
      int8_t *in_zp = input_zp + oc * C8NUM;
      int32_t *out_zp = output_zp + oc * C8NUM;
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param,
                          sliding, out_multiplier, left_shift, right_shift, per_channel);
                          sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0,
                          conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
                          per_channel);
                          conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
                          right_shift, acc_min, acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_,
                          conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel);
                          conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min,
                          acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                          conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
                          per_channel);
                          conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
                          right_shift, acc_min, acc_max);
      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
        ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t),
                         sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t),
                         sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
                         sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
                         conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
                         conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                         conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
                         sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int8_t),
                         sliding->in_sw_step_ * sizeof(int8_t), sliding->in_kh_step_ * sizeof(int8_t),
                         sliding->in_kw_step_ * sizeof(int8_t), in_zp, out_zp, out_multiplier, left_shift, right_shift,
                         acc_min, acc_max);
#else
        DepthwiseCenterInt8(
          out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
          conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
          sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier,
          left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
          conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
        DepthwiseCenterInt8(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
                            sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
                            sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
                            sliding->in_kh_step_, sliding->in_kw_step_, in_zp, out_zp, out_multiplier, left_shift,
                            right_shift, acc_min, acc_max);
#endif
      }
    }  // output C4 loop
    }  // output C8 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
  // output nhwc8
}
/*conv depthwise sliding window int8 end*/
/*conv depthwise sliding window perchannel int8 end*/
/*deconv depthwise int8 begin*/
void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,
@@ -27,8 +27,9 @@ extern "C" {
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
                const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
                  const SlidingWindowParam *sliding, int task_id);
void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
                  const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
@@ -965,6 +965,45 @@ void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
  }
}
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    int nhwc8_batch_offset = 0;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel,
               channel);
      }
      nhwc8_batch_offset += nhwc8_batch_unit_offset;
    }
  } else {
    size_t ori_input_size = batch * plane * channel;
    memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  }
}
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      int nhwc8_batch_offset = b * nhwc8_batch_unit_offset;
      for (int i = 0; i < plane; i++) {
        memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM,
               channel);
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel;
    memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  }
}
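A quick round-trip sketch of the NHWC ↔ NHWC8 conversion above; the tensor shape is made up for illustration. With channel = 10, c8 = UP_DIV(10, 8) = 2, so each pixel occupies 16 int8 lanes in the padded layout, the last 6 being padding. Note the pack call copies only `channel` bytes per pixel, so the pad lanes are left untouched; zero the buffer first if the kernels may read them:

    #include <string.h>

    void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
    void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);

    // Hypothetical round trip: 1x2x2x10 NHWC tensor padded to NHWC8 and back.
    void Nhwc8RoundTripExample(const int8_t src[40], int8_t restored[40]) {
      int8_t padded[1 * 2 * 2 * 16];      // 2 channel blocks of 8 per pixel
      memset(padded, 0, sizeof(padded));  // pad lanes are not written by the pack call
      PackNHWCToNHWC8Int8(src, padded, 1, 2 * 2, 10);
      PackNHWC8ToNHWCInt8(padded, restored, 1, 2 * 2, 10);
    }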
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
  int nhwc4_batch_offset = 0;
  int c4 = UP_DIV(channel, C4NUM);
@@ -1270,6 +1309,25 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                             ConvQuantArg *quant_qrg) {
  int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  for (int c = 0; c < channel; c++) {
    if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
      weight_zp = quant_qrg->filter_quant_args_[c].zp_;
    }
    int c8_block_num = c / C8NUM;
    int c8_block_rem = c % C8NUM;
    const int8_t *src_c = origin_weight + c * plane;
    int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM;
    for (int k = 0; k < plane; k++) {
      const int8_t *src_kernel = src_c + k;
      int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem;
      *dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
    }
  }
}
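To make the packed layout concrete, the destination index computed by the loop above is dst_index = (c / C8NUM) * plane * C8NUM + k * C8NUM + (c % C8NUM), with the per-channel weight zero point already subtracted into the int16 value. For example, with a 3x3 kernel (plane = 9), channel c = 9 and tap k = 2 land at 1 * 9 * 8 + 2 * 8 + 1 = 89.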
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                                   ConvQuantArg *quant_qrg) {
  int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  for (int c = 0; c < channel; c++) {
    if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
      weight_zp = quant_qrg->filter_quant_args_[c].zp_;
@@ -96,6 +96,10 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
@@ -114,6 +118,9 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                             ConvQuantArg *quant_qrg);
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                                   ConvQuantArg *quant_qrg);
#ifdef __cplusplus
}
#endif
@@ -177,8 +177,17 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
                                               const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
  auto kernel =
    new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  kernel::LiteKernel *kernel;
  auto act_quant_size =
    MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
  if (act_quant_size == 1) {  // per tensor
    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  } else {  // per channel
    kernel =
      new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
@@ -37,6 +37,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
    free(packed_weight_);
    packed_weight_ = nullptr;
  }
  FreeTmpQuant();
  FreeQuantParam();
}
@@ -45,8 +46,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
  // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
  auto weight_tensor = in_tensors_[kWeightIndex];
  auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
  int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
  int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
  int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
  int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
  packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -55,35 +56,36 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
  PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                          weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
  bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
  bias_data_ = reinterpret_cast<int32_t *>(malloc(C8NUM * OC8 * sizeof(int32_t)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t));
  memset(bias_data_, 0, C8NUM * OC8 * sizeof(int32_t));
  if (in_tensors_.size() == kInputSize2) {
    auto bias_tensor = in_tensors_.at(kBiasIndex);
    auto ori_bias = reinterpret_cast<int32_t *>(bias_tensor->MutableData());
    memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t));
  }
  conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
  conv_param_->thread_num_ = MSMIN(thread_count_, OC8);
  return RET_OK;
}
int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
  int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
                        UP_DIV(conv_param_->input_channel_, 4);
  packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  if (conv_param_->input_channel_ % C4NUM != 0) {
  if (conv_param_->input_channel_ % C8NUM != 0) {
    need_align_ = true;
    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM *
                           UP_DIV(conv_param_->output_channel_, C4NUM);
    int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM *
                          UP_DIV(conv_param_->input_channel_, C8NUM);
    packed_input_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int8_t)));
    if (packed_input_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return RET_ERROR;
    }
    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM *
                           UP_DIV(conv_param_->output_channel_, C8NUM);
    packed_output_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_output_size * sizeof(int8_t)));
    if (packed_output_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -93,6 +95,136 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
  return RET_OK;
}
void ConvolutionDepthwiseSWInt8CPUKernel::FreeTmpQuant() {
  if (input_scale_ != nullptr) {
    free(input_scale_);
    input_scale_ = nullptr;
  }
  if (input_zp_ != nullptr) {
    free(input_zp_);
    input_zp_ = nullptr;
  }
  if (weight_scale_ != nullptr) {
    free(weight_scale_);
    weight_scale_ = nullptr;
  }
  if (output_scale_ != nullptr) {
    free(output_scale_);
    output_scale_ = nullptr;
  }
  if (output_zp_ != nullptr) {
    free(output_zp_);
    output_zp_ = nullptr;
  }
}
int ConvolutionDepthwiseSWInt8CPUKernel::ReinitFreeBefore() {
  FreeTmpQuant();
  if (conv_quant_arg_->real_multiplier_ != nullptr) {
    free(conv_quant_arg_->real_multiplier_);
    conv_quant_arg_->real_multiplier_ = nullptr;
  }
  if (conv_quant_arg_->left_shift_ != nullptr) {
    free(conv_quant_arg_->left_shift_);
    conv_quant_arg_->left_shift_ = nullptr;
  }
  if (conv_quant_arg_->right_shift_ != nullptr) {
    free(conv_quant_arg_->right_shift_);
    conv_quant_arg_->right_shift_ = nullptr;
  }
  if (conv_quant_arg_->quant_multiplier_ != nullptr) {
    free(conv_quant_arg_->quant_multiplier_);
    conv_quant_arg_->quant_multiplier_ = nullptr;
  }
  if (conv_quant_arg_->out_act_min_ != nullptr) {
    free(conv_quant_arg_->out_act_min_);
    conv_quant_arg_->out_act_min_ = nullptr;
  }
  if (conv_quant_arg_->out_act_max_ != nullptr) {
    free(conv_quant_arg_->out_act_max_);
    conv_quant_arg_->out_act_max_ = nullptr;
  }
  return RET_OK;
}
int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
  ReinitFreeBefore();  // remalloc quant param buffer
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto channel = conv_param_->input_channel_;
  input_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  input_zp_ = reinterpret_cast<int8_t *>(malloc(channel * sizeof(int8_t)));
  if (input_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto input_quant_arg = input_tensor->GetQuantParams().front();
      input_zp_[i] = input_quant_arg.zeroPoint;
      input_scale_[i] = input_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto input_quant_arg = input_tensor->GetQuantParams()[i];
      input_zp_[i] = input_quant_arg.zeroPoint;
      input_scale_[i] = input_quant_arg.scale;
    }
  }
  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  output_zp_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  if (output_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto output_quant_arg = output_tensor->GetQuantParams().front();
      output_zp_[i] = output_quant_arg.zeroPoint;
      output_scale_[i] = output_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto output_quant_arg = output_tensor->GetQuantParams()[i];
      output_zp_[i] = output_quant_arg.zeroPoint;
      output_scale_[i] = output_quant_arg.scale;
    }
  }
  conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(channel * sizeof(double)));
  conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  weight_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  if (weight_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto weight_quant_arg = weight_tensor->GetQuantParams().front();
      weight_scale_[i] = weight_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto weight_quant_arg = weight_tensor->GetQuantParams()[i];
      weight_scale_[i] = weight_quant_arg.scale;
    }
  }
  for (int i = 0; i < channel; ++i) {
    const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
    double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
    conv_quant_arg_->real_multiplier_[i] = real_multiplier;
    QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
                           &conv_quant_arg_->right_shift_[i]);
  }
  // now only consider per tensor for output
  bool relu = conv_param_->act_type_ == ActType_Relu;
  bool relu6 = conv_param_->act_type_ == ActType_Relu6;
  for (int i = 0; i < channel; ++i) {
    CalculateActivationRangeQuantized(relu, relu6, output_zp_[i], output_scale_[i],
                                      &conv_param_->conv_quant_arg_.out_act_min_[i],
                                      &conv_param_->conv_quant_arg_.out_act_max_[i]);
  }
  return RET_OK;
}
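QuantizeRoundParameter splits each per-channel real_multiplier into a Q31 fixed-point multiplier plus shifts. A hypothetical stand-in showing the decomposition (the library's exact rounding may differ); the convention assumed from the usage above is that exactly one shift is nonzero and right_shift is stored non-positive, since the kernels later apply -right_shift as the exponent:

    #include <math.h>
    #include <stdint.h>

    static void DecomposeMultiplier(double real_multiplier, int32_t *quant_multiplier,
                                    int32_t *left_shift, int32_t *right_shift) {
      int shift = 0;
      const double q = frexp(real_multiplier, &shift);  // real = q * 2^shift, q in [0.5, 1)
      int64_t q_fixed = (int64_t)round(q * (double)(1ll << 31));
      if (q_fixed == (1ll << 31)) {  // q rounded up to 1.0: halve and bump the shift
        q_fixed /= 2;
        ++shift;
      }
      *quant_multiplier = (int32_t)q_fixed;
      *left_shift = shift > 0 ? shift : 0;
      *right_shift = shift > 0 ? 0 : shift;  // non-positive by convention
    }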
int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
  sliding = new (std::nothrow) SlidingWindowParam;
  if (sliding == nullptr) {
@@ -107,13 +239,19 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
  ConvolutionBaseCPUKernel::Init();
  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
  InitSlidingParamConvDw(sliding, conv_param_, C8NUM);
  auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set quant param failed.";
    return ret;
  }
  ret = ReinitQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "reinit quant param failed.";
    return ret;
  }
  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
@@ -123,8 +261,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
}
int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) {
  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
               sliding, task_id);
  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), input_zp_,
               output_zp_, conv_param_, sliding, task_id);
  return RET_OK;
}
@@ -157,7 +295,12 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_addr = reinterpret_cast<int8_t *>(input_tensor->MutableData());
  PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_);
  if (need_align_) {
    PackNHWCToNHWC8Int8(input_addr, packed_input_, conv_param_->input_batch_,
                        conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
  } else {
    packed_input_ = input_addr;
  }
  auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
  if (!need_align_) {
@@ -171,11 +314,11 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
  }
  if (need_align_) {
    PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
    PackNHWC8ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
                        conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
    context_->allocator->Free(packed_input_);
    context_->allocator->Free(packed_output_);
  }
  context_->allocator->Free(packed_input_);
  return RET_OK;
}
@@ -40,11 +40,21 @@ class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
  int Execute(int task_id);
 private:
  int ReinitQuantParam();
  int ReinitFreeBefore();
  void FreeTmpQuant();
  SlidingWindowParam *sliding = nullptr;
  int16_t *packed_weight_ = nullptr;
  int16_t *packed_input_ = nullptr;
  int8_t *packed_input_ = nullptr;
  int8_t *packed_output_ = nullptr;
  bool need_align_ = false;
  int8_t *input_zp_ = nullptr;
  float *input_scale_ = nullptr;
  float *weight_scale_ = nullptr;
  int32_t *output_zp_ = nullptr;
  float *output_scale_ = nullptr;
};
}  // namespace mindspore::kernel
@@ -52,8 +52,8 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                          weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
  PackDeconvDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                                weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
  bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
  if (bias_data_ == nullptr) {