Merge pull request !6038 from yangruoqi713/act_per_channel
@@ -7,13 +7,15 @@
 .type ConvDwInt8Center, %function
 #endif
-// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width,
-// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
-// size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift,
-// int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
+// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
+// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
+// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
+// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
+// int32_t *acc_min, int32_t *acc_max)
 // x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w,
 // x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step
-// x14: out_multiplier, #56: left_shift, #64: right_shift, #72: out_zp, #80: acc_min, #88: acc_max
+// x14: in_zp, #56: out_zp, #64: out_multiplier, #72: left_shift, #80: right_shift, #88: acc_min, #96: acc_max
 ConvDwInt8Center:
 // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
@@ -33,489 +35,174 @@ ConvDwInt8Center:
 ldr x12, [sp, #32]
 ldr x13, [sp, #40]
-ldr w14, [sp, #56]
-dup v26.4s, w14
+ldr x14, [sp, #48] // input_zp
+ld1 {v19.8b}, [x14], #8
+ldr x15, [sp, #56] // output_zp
+ld1 {v20.4s}, [x15], #16
+ld1 {v21.4s}, [x15], #16
-ldr x15, [sp, #48]
-dup v27.4s, w15
+ldr x16, [sp, #64] // out_multiplier
+ld1 {v22.4s}, [x16], #16
+ld1 {v23.4s}, [x16], #16
-ldr w16, [sp, #64]
-dup v28.4s, w16
+ldr x17, [sp, #72] // left_shift
+ld1 {v24.4s}, [x17], #16
+ld1 {v25.4s}, [x17], #16
-ldr w17, [sp, #72]
-dup v29.4s, w17
-ldr w18, [sp, #80]
-dup v30.4s, w18
+ldr x18, [sp, #80] // right shift
+ld1 {v26.4s}, [x18], #16
+ld1 {v27.4s}, [x18], #16
-ldr w19, [sp, #88]
-dup v31.4s, w19
+ldr x19, [sp, #88] // acc_min
+ld1 {v28.4s}, [x19], #16
+ld1 {v29.4s}, [x19], #16
-ld1 {v24.4s}, [x3]
+ldr x20, [sp, #96] // acc_max
+ld1 {v30.4s}, [x20], #16
+ld1 {v31.4s}, [x20], #16
+ld1 {v17.4s}, [x3], #16
+ld1 {v18.4s}, [x3], #16
 LoopH:
 mov x23, x1
 mov x24, x5
 mov x3, x0
-cmp x24, #8
-blt LoopW
-cmp x24, #16
-blt LoopW8
-LoopW16:
-mov x19, #16
+LoopW4:
+mov x19, #4
 mul x19, x19, x11
+mov x25, #4
+mul x25, x25, x9
 mov x16, x23
 mov x17, x2
 mov x20, x6
-mov v0.16b, v24.16b
-mov v1.16b, v24.16b
-mov v2.16b, v24.16b
-mov v3.16b, v24.16b
-mov v4.16b, v24.16b
-mov v5.16b, v24.16b
-mov v6.16b, v24.16b
-mov v7.16b, v24.16b
-mov v8.16b, v24.16b
-mov v9.16b, v24.16b
-mov v10.16b, v24.16b
-mov v11.16b, v24.16b
-mov v12.16b, v24.16b
-mov v13.16b, v24.16b
-mov v14.16b, v24.16b
-mov v15.16b, v24.16b
-LoopKh16:
+mov v0.16b, v17.16b
+mov v1.16b, v18.16b
+mov v2.16b, v17.16b
+mov v3.16b, v18.16b
+mov v4.16b, v17.16b
+mov v5.16b, v18.16b
+mov v6.16b, v17.16b
+mov v7.16b, v18.16b
+LoopKh4:
 mov x18, x7
 mov x21, x16
-LoopKw16:
+LoopKw4:
 mov x22, x21
-ld1 {v25.4h}, [x17], #8
-ld1 {v16.4h}, [x22], x11
-ld1 {v17.4h}, [x22], x11
-smlal v0.4s, v16.4h, v25.4h
-smlal v1.4s, v17.4h, v25.4h
-ld1 {v18.4h}, [x22], x11
-ld1 {v19.4h}, [x22], x11
-smlal v2.4s, v18.4h, v25.4h
-smlal v3.4s, v19.4h, v25.4h
-ld1 {v20.4h}, [x22], x11
-ld1 {v21.4h}, [x22], x11
-smlal v4.4s, v20.4h, v25.4h
-smlal v5.4s, v21.4h, v25.4h
-ld1 {v22.4h}, [x22], x11
-ld1 {v23.4h}, [x22], x11
-smlal v6.4s, v22.4h, v25.4h
-smlal v7.4s, v23.4h, v25.4h
-ld1 {v16.4h}, [x22], x11
-ld1 {v17.4h}, [x22], x11
-smlal v8.4s, v16.4h, v25.4h
-smlal v9.4s, v17.4h, v25.4h
-ld1 {v18.4h}, [x22], x11
-ld1 {v19.4h}, [x22], x11
-smlal v10.4s, v18.4h, v25.4h
-smlal v11.4s, v19.4h, v25.4h
-ld1 {v20.4h}, [x22], x11
-ld1 {v21.4h}, [x22], x11
-smlal v12.4s, v20.4h, v25.4h
-smlal v13.4s, v21.4h, v25.4h
-ld1 {v22.4h}, [x22], x11
-ld1 {v23.4h}, [x22], x11
-smlal v14.4s, v22.4h, v25.4h
-smlal v15.4s, v23.4h, v25.4h
-subs x18, x18, #1
-add x21, x21, x13
-bne LoopKw16
-add x16, x16, x12
-subs x20, x20, #1
-bne LoopKh16
-sqshl v0.4s, v0.4s, v26.4s
-sqshl v1.4s, v1.4s, v26.4s
-sqshl v2.4s, v2.4s, v26.4s
-sqshl v3.4s, v3.4s, v26.4s
-sqshl v4.4s, v4.4s, v26.4s
-sqshl v5.4s, v5.4s, v26.4s
-sqshl v6.4s, v6.4s, v26.4s
-sqshl v7.4s, v7.4s, v26.4s
-sqshl v8.4s, v8.4s, v26.4s
-sqshl v9.4s, v9.4s, v26.4s
-sqshl v10.4s, v10.4s, v26.4s
-sqshl v11.4s, v11.4s, v26.4s
-sqshl v12.4s, v12.4s, v26.4s
-sqshl v13.4s, v13.4s, v26.4s
-sqshl v14.4s, v14.4s, v26.4s
-sqshl v15.4s, v15.4s, v26.4s
-sqrdmulh v0.4s, v0.4s, v27.4s
-sqrdmulh v1.4s, v1.4s, v27.4s
-sqrdmulh v2.4s, v2.4s, v27.4s
-sqrdmulh v3.4s, v3.4s, v27.4s
-sqrdmulh v4.4s, v4.4s, v27.4s
-sqrdmulh v5.4s, v5.4s, v27.4s
-sqrdmulh v6.4s, v6.4s, v27.4s
-sqrdmulh v7.4s, v7.4s, v27.4s
-sqrdmulh v8.4s, v8.4s, v27.4s
-sqrdmulh v9.4s, v9.4s, v27.4s
-sqrdmulh v10.4s, v10.4s, v27.4s
-sqrdmulh v11.4s, v11.4s, v27.4s
-sqrdmulh v12.4s, v12.4s, v27.4s
-sqrdmulh v13.4s, v13.4s, v27.4s
-sqrdmulh v14.4s, v14.4s, v27.4s
-sqrdmulh v15.4s, v15.4s, v27.4s
-and v16.16b, v28.16b, v0.16b
-sshr v16.4s, v16.4s, #31
-sqadd v0.4s, v0.4s, v16.4s
-srshl v0.4s, v0.4s, v28.4s
-and v17.16b, v28.16b, v1.16b
-sshr v17.4s, v17.4s, #31
-sqadd v1.4s, v1.4s, v17.4s
-srshl v1.4s, v1.4s, v28.4s
-and v18.16b, v28.16b, v2.16b
-sshr v18.4s, v18.4s, #31
-sqadd v2.4s, v2.4s, v18.4s
-srshl v2.4s, v2.4s, v28.4s
-and v19.16b, v28.16b, v3.16b
-sshr v19.4s, v19.4s, #31
-sqadd v3.4s, v3.4s, v19.4s
-srshl v3.4s, v3.4s, v28.4s
-and v20.16b, v28.16b, v4.16b
-sshr v20.4s, v20.4s, #31
-sqadd v4.4s, v4.4s, v20.4s
-srshl v4.4s, v4.4s, v28.4s
-and v21.16b, v28.16b, v5.16b
-sshr v21.4s, v21.4s, #31
-sqadd v5.4s, v5.4s, v21.4s
-srshl v5.4s, v5.4s, v28.4s
-and v22.16b, v28.16b, v6.16b
-sshr v22.4s, v22.4s, #31
-sqadd v6.4s, v6.4s, v22.4s
-srshl v6.4s, v6.4s, v28.4s
-and v23.16b, v28.16b, v7.16b
-sshr v23.4s, v23.4s, #31
-sqadd v7.4s, v7.4s, v23.4s
-srshl v7.4s, v7.4s, v28.4s
-and v16.16b, v28.16b, v8.16b
-sshr v16.4s, v16.4s, #31
-sqadd v8.4s, v8.4s, v16.4s
-srshl v8.4s, v8.4s, v28.4s
-and v17.16b, v28.16b, v9.16b
-sshr v17.4s, v17.4s, #31
-sqadd v9.4s, v9.4s, v17.4s
-srshl v9.4s, v9.4s, v28.4s
-and v18.16b, v28.16b, v10.16b
-sshr v18.4s, v18.4s, #31
-sqadd v10.4s, v10.4s, v18.4s
-srshl v10.4s, v10.4s, v28.4s
-and v19.16b, v28.16b, v11.16b
-sshr v19.4s, v19.4s, #31
-sqadd v11.4s, v11.4s, v19.4s
-srshl v11.4s, v11.4s, v28.4s
-and v20.16b, v28.16b, v12.16b
-sshr v20.4s, v20.4s, #31
-sqadd v12.4s, v12.4s, v20.4s
-srshl v12.4s, v12.4s, v28.4s
-and v21.16b, v28.16b, v13.16b
-sshr v21.4s, v21.4s, #31
-sqadd v13.4s, v13.4s, v21.4s
-srshl v13.4s, v13.4s, v28.4s
-and v22.16b, v28.16b, v14.16b
-sshr v22.4s, v22.4s, #31
-sqadd v14.4s, v14.4s, v22.4s
-srshl v14.4s, v14.4s, v28.4s
-and v23.16b, v28.16b, v15.16b
-sshr v23.4s, v23.4s, #31
-sqadd v15.4s, v15.4s, v23.4s
-srshl v15.4s, v15.4s, v28.4s
-add v0.4s, v0.4s, v29.4s
-add v1.4s, v1.4s, v29.4s
-add v2.4s, v2.4s, v29.4s
-add v3.4s, v3.4s, v29.4s
-add v4.4s, v4.4s, v29.4s
-add v5.4s, v5.4s, v29.4s
-add v6.4s, v6.4s, v29.4s
-add v7.4s, v7.4s, v29.4s
-add v8.4s, v8.4s, v29.4s
-add v9.4s, v9.4s, v29.4s
-add v10.4s, v10.4s, v29.4s
-add v11.4s, v11.4s, v29.4s
-add v12.4s, v12.4s, v29.4s
-add v13.4s, v13.4s, v29.4s
-add v14.4s, v14.4s, v29.4s
-add v15.4s, v15.4s, v29.4s
-smax v0.4s, v0.4s, v30.4s
-smax v1.4s, v1.4s, v30.4s
-smax v2.4s, v2.4s, v30.4s
-smax v3.4s, v3.4s, v30.4s
-smax v4.4s, v4.4s, v30.4s
-smax v5.4s, v5.4s, v30.4s
-smax v6.4s, v6.4s, v30.4s
-smax v7.4s, v7.4s, v30.4s
-smax v8.4s, v8.4s, v30.4s
-smax v9.4s, v9.4s, v30.4s
-smax v10.4s, v10.4s, v30.4s
-smax v11.4s, v11.4s, v30.4s
-smax v12.4s, v12.4s, v30.4s
-smax v13.4s, v13.4s, v30.4s
-smax v14.4s, v14.4s, v30.4s
-smax v15.4s, v15.4s, v30.4s
-smin v0.4s, v0.4s, v31.4s
-smin v1.4s, v1.4s, v31.4s
-smin v2.4s, v2.4s, v31.4s
-smin v3.4s, v3.4s, v31.4s
-smin v4.4s, v4.4s, v31.4s
-smin v5.4s, v5.4s, v31.4s
-smin v6.4s, v6.4s, v31.4s
-smin v7.4s, v7.4s, v31.4s
-smin v8.4s, v8.4s, v31.4s
-smin v9.4s, v9.4s, v31.4s
-smin v10.4s, v10.4s, v31.4s
-smin v11.4s, v11.4s, v31.4s
-smin v12.4s, v12.4s, v31.4s
-smin v13.4s, v13.4s, v31.4s
-smin v14.4s, v14.4s, v31.4s
-smin v15.4s, v15.4s, v31.4s
+ld1 {v16.8h}, [x17], #16
-sqxtn v0.4h, v0.4s
-sqxtn v1.4h, v1.4s
-sqxtn v2.4h, v2.4s
-sqxtn v3.4h, v3.4s
-sqxtn v4.4h, v4.4s
-sqxtn v5.4h, v5.4s
-sqxtn v6.4h, v6.4s
-sqxtn v7.4h, v7.4s
-sqxtn v8.4h, v8.4s
-sqxtn v9.4h, v9.4s
-sqxtn v10.4h, v10.4s
-sqxtn v11.4h, v11.4s
-sqxtn v12.4h, v12.4s
-sqxtn v13.4h, v13.4s
-sqxtn v14.4h, v14.4s
-sqxtn v15.4h, v15.4s
-sqxtn v0.8b, v0.8h
-sqxtn v1.8b, v1.8h
-sqxtn v2.8b, v2.8h
-sqxtn v3.8b, v3.8h
-sqxtn v4.8b, v4.8h
-sqxtn v5.8b, v5.8h
-sqxtn v6.8b, v6.8h
-sqxtn v7.8b, v7.8h
-sqxtn v8.8b, v8.8h
-sqxtn v9.8b, v9.8h
-sqxtn v10.8b, v10.8h
-sqxtn v11.8b, v11.8h
-sqxtn v12.8b, v12.8h
-sqxtn v13.8b, v13.8h
-sqxtn v14.8b, v14.8h
-sqxtn v15.8b, v15.8h
-add x17, x3, #1
-add x18, x3, #2
-add x21, x3, #3
-st1 {v0.b}[0], [x3], x9
-st1 {v0.b}[1], [x17], x9
-st1 {v0.b}[2], [x18], x9
-st1 {v0.b}[3], [x21], x9
-st1 {v1.b}[0], [x3], x9
-st1 {v1.b}[1], [x17], x9
-st1 {v1.b}[2], [x18], x9
-st1 {v1.b}[3], [x21], x9
-st1 {v2.b}[0], [x3], x9
-st1 {v2.b}[1], [x17], x9
-st1 {v2.b}[2], [x18], x9
-st1 {v2.b}[3], [x21], x9
-st1 {v3.b}[0], [x3], x9
-st1 {v3.b}[1], [x17], x9
-st1 {v3.b}[2], [x18], x9
-st1 {v3.b}[3], [x21], x9
-st1 {v4.b}[0], [x3], x9
-st1 {v4.b}[1], [x17], x9
-st1 {v4.b}[2], [x18], x9
-st1 {v4.b}[3], [x21], x9
-st1 {v5.b}[0], [x3], x9
-st1 {v5.b}[1], [x17], x9
-st1 {v5.b}[2], [x18], x9
-st1 {v5.b}[3], [x21], x9
-st1 {v6.b}[0], [x3], x9
-st1 {v6.b}[1], [x17], x9
-st1 {v6.b}[2], [x18], x9
-st1 {v6.b}[3], [x21], x9
-st1 {v7.b}[0], [x3], x9
-st1 {v7.b}[1], [x17], x9
-st1 {v7.b}[2], [x18], x9
-st1 {v7.b}[3], [x21], x9
-st1 {v8.b}[0], [x3], x9
-st1 {v8.b}[1], [x17], x9
-st1 {v8.b}[2], [x18], x9
-st1 {v8.b}[3], [x21], x9
-st1 {v9.b}[0], [x3], x9
-st1 {v9.b}[1], [x17], x9
-st1 {v9.b}[2], [x18], x9
-st1 {v9.b}[3], [x21], x9
-st1 {v10.b}[0], [x3], x9
-st1 {v10.b}[1], [x17], x9
-st1 {v10.b}[2], [x18], x9
-st1 {v10.b}[3], [x21], x9
-st1 {v11.b}[0], [x3], x9
-st1 {v11.b}[1], [x17], x9
-st1 {v11.b}[2], [x18], x9
-st1 {v11.b}[3], [x21], x9
-st1 {v12.b}[0], [x3], x9
-st1 {v12.b}[1], [x17], x9
-st1 {v12.b}[2], [x18], x9
-st1 {v12.b}[3], [x21], x9
-st1 {v13.b}[0], [x3], x9
-st1 {v13.b}[1], [x17], x9
-st1 {v13.b}[2], [x18], x9
-st1 {v13.b}[3], [x21], x9
-st1 {v14.b}[0], [x3], x9
-st1 {v14.b}[1], [x17], x9
-st1 {v14.b}[2], [x18], x9
-st1 {v14.b}[3], [x21], x9
-st1 {v15.b}[0], [x3], x9
-st1 {v15.b}[1], [x17], x9
-st1 {v15.b}[2], [x18], x9
-st1 {v15.b}[3], [x21], x9
+ld1 {v15.8b}, [x22], x11
+ssubl v14.8h, v15.8b, v19.8b
+smlal v0.4s, v14.4h, v16.4h
+smlal2 v1.4s, v14.8h, v16.8h
+ld1 {v13.8b}, [x22], x11
+ssubl v12.8h, v13.8b, v19.8b
+smlal v2.4s, v12.4h, v16.4h
+smlal2 v3.4s, v12.8h, v16.8h
+ld1 {v11.8b}, [x22], x11
+ssubl v10.8h, v11.8b, v19.8b
+smlal v4.4s, v10.4h, v16.4h
+smlal2 v5.4s, v10.8h, v16.8h
+ld1 {v9.8b}, [x22], x11
+ssubl v8.8h, v9.8b, v19.8b
+smlal v6.4s, v8.4h, v16.4h
+smlal2 v7.4s, v8.8h, v16.8h
-add x23, x23, x19
-sub x24, x24, #16
-cmp x24, #0
-ble LoopWEnd
-cmp x24, #8
-blt LoopW
-cmp x24, #16
-bge LoopW16
-LoopW8:
-mov x19, #8
-mul x19, x19, x11
-mov x16, x23
-mov x17, x2
-mov x20, x6
-mov v0.16b, v24.16b
-mov v1.16b, v24.16b
-mov v2.16b, v24.16b
-mov v3.16b, v24.16b
-mov v4.16b, v24.16b
-mov v5.16b, v24.16b
-mov v6.16b, v24.16b
-mov v7.16b, v24.16b
-LoopKh8:
-mov x18, x7
-mov x21, x16
-LoopKw8:
-mov x22, x21
-ld1 {v25.4h}, [x17], #8
-ld1 {v16.4h}, [x22], x11
-ld1 {v17.4h}, [x22], x11
-smlal v0.4s, v16.4h, v25.4h
-smlal v1.4s, v17.4h, v25.4h
-ld1 {v18.4h}, [x22], x11
-ld1 {v19.4h}, [x22], x11
-smlal v2.4s, v18.4h, v25.4h
-smlal v3.4s, v19.4h, v25.4h
-ld1 {v20.4h}, [x22], x11
-ld1 {v21.4h}, [x22], x11
-smlal v4.4s, v20.4h, v25.4h
-smlal v5.4s, v21.4h, v25.4h
-ld1 {v22.4h}, [x22], x11
-ld1 {v23.4h}, [x22], x11
-smlal v6.4s, v22.4h, v25.4h
-smlal v7.4s, v23.4h, v25.4h
 subs x18, x18, #1
 add x21, x21, x13
-bne LoopKw8
+bne LoopKw4
 add x16, x16, x12
 subs x20, x20, #1
-bne LoopKh8
-sqshl v0.4s, v0.4s, v26.4s
-sqshl v1.4s, v1.4s, v26.4s
-sqshl v2.4s, v2.4s, v26.4s
-sqshl v3.4s, v3.4s, v26.4s
-sqshl v4.4s, v4.4s, v26.4s
-sqshl v5.4s, v5.4s, v26.4s
-sqshl v6.4s, v6.4s, v26.4s
-sqshl v7.4s, v7.4s, v26.4s
-sqrdmulh v0.4s, v0.4s, v27.4s
-sqrdmulh v1.4s, v1.4s, v27.4s
-sqrdmulh v2.4s, v2.4s, v27.4s
-sqrdmulh v3.4s, v3.4s, v27.4s
-sqrdmulh v4.4s, v4.4s, v27.4s
-sqrdmulh v5.4s, v5.4s, v27.4s
-sqrdmulh v6.4s, v6.4s, v27.4s
-sqrdmulh v7.4s, v7.4s, v27.4s
-and v16.16b, v28.16b, v0.16b
-sshr v16.4s, v16.4s, #31
-sqadd v0.4s, v0.4s, v16.4s
-srshl v0.4s, v0.4s, v28.4s
-and v17.16b, v28.16b, v1.16b
-sshr v17.4s, v17.4s, #31
-sqadd v1.4s, v1.4s, v17.4s
-srshl v1.4s, v1.4s, v28.4s
-and v18.16b, v28.16b, v2.16b
-sshr v18.4s, v18.4s, #31
-sqadd v2.4s, v2.4s, v18.4s
-srshl v2.4s, v2.4s, v28.4s
-and v19.16b, v28.16b, v3.16b
-sshr v19.4s, v19.4s, #31
-sqadd v3.4s, v3.4s, v19.4s
-srshl v3.4s, v3.4s, v28.4s
-and v20.16b, v28.16b, v4.16b
-sshr v20.4s, v20.4s, #31
-sqadd v4.4s, v4.4s, v20.4s
-srshl v4.4s, v4.4s, v28.4s
-and v21.16b, v28.16b, v5.16b
-sshr v21.4s, v21.4s, #31
-sqadd v5.4s, v5.4s, v21.4s
-srshl v5.4s, v5.4s, v28.4s
-and v22.16b, v28.16b, v6.16b
-sshr v22.4s, v22.4s, #31
-sqadd v6.4s, v6.4s, v22.4s
-srshl v6.4s, v6.4s, v28.4s
-and v23.16b, v28.16b, v7.16b
-sshr v23.4s, v23.4s, #31
-sqadd v7.4s, v7.4s, v23.4s
-srshl v7.4s, v7.4s, v28.4s
-add v0.4s, v0.4s, v29.4s
-add v1.4s, v1.4s, v29.4s
-add v2.4s, v2.4s, v29.4s
-add v3.4s, v3.4s, v29.4s
-add v4.4s, v4.4s, v29.4s
-add v5.4s, v5.4s, v29.4s
-add v6.4s, v6.4s, v29.4s
-add v7.4s, v7.4s, v29.4s
-smax v0.4s, v0.4s, v30.4s
-smax v1.4s, v1.4s, v30.4s
-smax v2.4s, v2.4s, v30.4s
-smax v3.4s, v3.4s, v30.4s
-smax v4.4s, v4.4s, v30.4s
-smax v5.4s, v5.4s, v30.4s
-smax v6.4s, v6.4s, v30.4s
-smax v7.4s, v7.4s, v30.4s
-smin v0.4s, v0.4s, v31.4s
+bne LoopKh4
+sqshl v0.4s, v0.4s, v24.4s
+sqshl v1.4s, v1.4s, v25.4s
+sqshl v2.4s, v2.4s, v24.4s
+sqshl v3.4s, v3.4s, v25.4s
+sqshl v4.4s, v4.4s, v24.4s
+sqshl v5.4s, v5.4s, v25.4s
+sqshl v6.4s, v6.4s, v24.4s
+sqshl v7.4s, v7.4s, v25.4s
+sqrdmulh v0.4s, v0.4s, v22.4s
+sqrdmulh v1.4s, v1.4s, v23.4s
+sqrdmulh v2.4s, v2.4s, v22.4s
+sqrdmulh v3.4s, v3.4s, v23.4s
+sqrdmulh v4.4s, v4.4s, v22.4s
+sqrdmulh v5.4s, v5.4s, v23.4s
+sqrdmulh v6.4s, v6.4s, v22.4s
+sqrdmulh v7.4s, v7.4s, v23.4s
+and v15.16b, v26.16b, v0.16b
+sshr v15.4s, v15.4s, #31
+sqadd v0.4s, v0.4s, v15.4s
+srshl v0.4s, v0.4s, v26.4s
+and v14.16b, v27.16b, v1.16b
+sshr v14.4s, v14.4s, #31
+sqadd v1.4s, v1.4s, v14.4s
+srshl v1.4s, v1.4s, v27.4s
+and v13.16b, v26.16b, v2.16b
+sshr v13.4s, v13.4s, #31
+sqadd v2.4s, v2.4s, v13.4s
+srshl v2.4s, v2.4s, v26.4s
+and v12.16b, v27.16b, v3.16b
+sshr v12.4s, v12.4s, #31
+sqadd v3.4s, v3.4s, v12.4s
+srshl v3.4s, v3.4s, v27.4s
+and v11.16b, v26.16b, v4.16b
+sshr v11.4s, v11.4s, #31
+sqadd v4.4s, v4.4s, v11.4s
+srshl v4.4s, v4.4s, v26.4s
+and v10.16b, v27.16b, v5.16b
+sshr v10.4s, v10.4s, #31
+sqadd v5.4s, v5.4s, v10.4s
+srshl v5.4s, v5.4s, v27.4s
+and v9.16b, v26.16b, v6.16b
+sshr v9.4s, v9.4s, #31
+sqadd v6.4s, v6.4s, v9.4s
+srshl v6.4s, v6.4s, v26.4s
+and v8.16b, v27.16b, v7.16b
+sshr v8.4s, v8.4s, #31
+sqadd v7.4s, v7.4s, v8.4s
+srshl v7.4s, v7.4s, v27.4s
+add v0.4s, v0.4s, v20.4s
+add v1.4s, v1.4s, v21.4s
+add v2.4s, v2.4s, v20.4s
+add v3.4s, v3.4s, v21.4s
+add v4.4s, v4.4s, v20.4s
+add v5.4s, v5.4s, v21.4s
+add v6.4s, v6.4s, v20.4s
+add v7.4s, v7.4s, v21.4s
+smax v0.4s, v0.4s, v28.4s
+smax v1.4s, v1.4s, v29.4s
+smax v2.4s, v2.4s, v28.4s
+smax v3.4s, v3.4s, v29.4s
+smax v4.4s, v4.4s, v28.4s
+smax v5.4s, v5.4s, v29.4s
+smax v6.4s, v6.4s, v28.4s
+smax v7.4s, v7.4s, v29.4s
+smin v0.4s, v0.4s, v30.4s
 smin v1.4s, v1.4s, v31.4s
-smin v2.4s, v2.4s, v31.4s
+smin v2.4s, v2.4s, v30.4s
 smin v3.4s, v3.4s, v31.4s
-smin v4.4s, v4.4s, v31.4s
+smin v4.4s, v4.4s, v30.4s
 smin v5.4s, v5.4s, v31.4s
-smin v6.4s, v6.4s, v31.4s
+smin v6.4s, v6.4s, v30.4s
 smin v7.4s, v7.4s, v31.4s
 sqxtn v0.4h, v0.4s
@@ -535,93 +222,81 @@ ConvDwInt8Center:
 sqxtn v6.8b, v6.8h
 sqxtn v7.8b, v7.8h
-add x17, x3, #1
-add x18, x3, #2
-add x21, x3, #3
-st1 {v0.b}[0], [x3], x9
-st1 {v0.b}[1], [x17], x9
-st1 {v0.b}[2], [x18], x9
-st1 {v0.b}[3], [x21], x9
-st1 {v1.b}[0], [x3], x9
-st1 {v1.b}[1], [x17], x9
-st1 {v1.b}[2], [x18], x9
-st1 {v1.b}[3], [x21], x9
-st1 {v2.b}[0], [x3], x9
-st1 {v2.b}[1], [x17], x9
-st1 {v2.b}[2], [x18], x9
-st1 {v2.b}[3], [x21], x9
-st1 {v3.b}[0], [x3], x9
-st1 {v3.b}[1], [x17], x9
-st1 {v3.b}[2], [x18], x9
-st1 {v3.b}[3], [x21], x9
-st1 {v4.b}[0], [x3], x9
-st1 {v4.b}[1], [x17], x9
-st1 {v4.b}[2], [x18], x9
-st1 {v4.b}[3], [x21], x9
-st1 {v5.b}[0], [x3], x9
-st1 {v5.b}[1], [x17], x9
-st1 {v5.b}[2], [x18], x9
-st1 {v5.b}[3], [x21], x9
-st1 {v6.b}[0], [x3], x9
-st1 {v6.b}[1], [x17], x9
-st1 {v6.b}[2], [x18], x9
-st1 {v6.b}[3], [x21], x9
-st1 {v7.b}[0], [x3], x9
-st1 {v7.b}[1], [x17], x9
-st1 {v7.b}[2], [x18], x9
-st1 {v7.b}[3], [x21], x9
+mov x16, x3
+add x17, x16, x9
+add x18, x17, x9
+add x21, x18, x9
+st1 {v0.s}[0], [x16], #4
+st1 {v1.s}[0], [x16], #4
+st1 {v2.s}[0], [x17], #4
+st1 {v3.s}[0], [x17], #4
+st1 {v4.s}[0], [x18], #4
+st1 {v5.s}[0], [x18], #4
+st1 {v6.s}[0], [x21], #4
+st1 {v7.s}[0], [x21], #4
+add x3, x3, x25
 add x23, x23, x19
-sub x24, x24, #8
+sub x24, x24, #4
 cmp x24, #0
 ble LoopWEnd
-cmp x24, #8
-bge LoopW8
+cmp x24, #4
+bge LoopW4
 LoopW:
 mov x16, x23
 mov x17, x2
 mov x20, x6
-mov v0.16b, v24.16b
+mov v0.16b, v17.16b
+mov v1.16b, v18.16b
 LoopKh:
 mov x18, x7
 mov x22, x16
 LoopKw:
-ld1 {v16.4h}, [x22], x13
-ld1 {v25.4h}, [x17], #8
-smlal v0.4s, v16.4h, v25.4h
+ld1 {v15.8b}, [x22], x13
+ssubl v14.8h, v15.8b, v19.8b
+ld1 {v16.8h}, [x17], #16
+smlal v0.4s, v14.4h, v16.4h
+smlal2 v1.4s, v14.8h, v16.8h
 subs x18, x18, #1
 bne LoopKw
 add x16, x16, x12
 subs x20, x20, #1
 bne LoopKh
-sqshl v0.4s, v0.4s, v26.4s
-sqrdmulh v0.4s, v0.4s, v27.4s
+sqshl v0.4s, v0.4s, v24.4s
+sqrdmulh v0.4s, v0.4s, v22.4s
+sqshl v1.4s, v1.4s, v25.4s
+sqrdmulh v1.4s, v1.4s, v23.4s
-and v16.16b, v28.16b, v0.16b
-sshr v16.4s, v16.4s, #31
-sqadd v0.4s, v0.4s, v16.4s
-srshl v0.4s, v0.4s, v28.4s
+and v15.16b, v26.16b, v0.16b
+sshr v15.4s, v15.4s, #31
+sqadd v0.4s, v0.4s, v15.4s
+srshl v0.4s, v0.4s, v26.4s
-add v0.4s, v0.4s, v29.4s
-smax v0.4s, v0.4s, v30.4s
-smin v0.4s, v0.4s, v31.4s
+and v14.16b, v27.16b, v1.16b
+sshr v14.4s, v14.4s, #31
+sqadd v1.4s, v1.4s, v14.4s
+srshl v1.4s, v1.4s, v27.4s
+add v0.4s, v0.4s, v20.4s
+smax v0.4s, v0.4s, v28.4s
+smin v0.4s, v0.4s, v30.4s
 sqxtn v0.4h, v0.4s
 sqxtn v0.8b, v0.8h
+add v1.4s, v1.4s, v21.4s
+smax v1.4s, v1.4s, v29.4s
+smin v1.4s, v1.4s, v31.4s
+sqxtn v1.4h, v1.4s
+sqxtn v1.8b, v1.8h
 mov x17, x3
-st1 {v0.b}[0], [x17], #1
-st1 {v0.b}[1], [x17], #1
-st1 {v0.b}[2], [x17], #1
-st1 {v0.b}[3], [x17], #1
+st1 {v0.s}[0], [x17], #4
+st1 {v1.s}[0], [x17], #4
 add x3, x3, x9
 add x23, x23, x11
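The scalar tail of each loop above (sqshl, sqrdmulh, the and/sshr/sqadd/srshl group, add, smax/smin, sqxtn) is the usual gemmlowp-style fixed-point requantization, now fed from per-channel vectors: v24/v25 hold the left shifts, v22/v23 the output multipliers, v26/v27 the right shifts, v20/v21 the output zero points, and v28-v31 the clamp bounds. A scalar C sketch of what one lane computes, mirroring the C fallback later in this change (it ignores the int32 saturation that the sq* instructions provide on the SIMD path):

    // One lane of the per-channel requantization pipeline (sketch).
    // acc is the int32 convolution sum with bias already added.
    static int32_t RequantizeLane(int32_t acc, int left_shift, int32_t out_multiplier,
                                  int right_shift, int32_t out_zp,
                                  int32_t acc_min, int32_t acc_max) {
      // sqshl: pre-scale by the per-channel left shift
      int64_t ab = (int64_t)(acc * (1 << left_shift)) * out_multiplier;
      // sqrdmulh: rounding doubling high half of the 32x32 product
      int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
      int32_t x = (int32_t)((ab + nudge) >> 31);
      // and/sshr/sqadd/srshl: round-to-nearest right shift (RoundingDivideByPOT)
      int32_t mask = (1 << right_shift) - 1;
      int32_t remainder = x & mask;
      int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      x = (x >> right_shift) + (remainder > threshold ? 1 : 0);
      x += out_zp;                    // add: output zero point
      x = x < acc_min ? acc_min : x;  // smax: clamp low
      x = x > acc_max ? acc_max : x;  // smin: clamp high
      return x;                       // sqxtn then narrows to int8
    }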
@@ -45,10 +45,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
 void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
 size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
 size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
-void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height,
+void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
 size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
-size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
-int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
+size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
+int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
+int32_t *acc_min, int32_t *acc_max);
 void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
 int output_channel, int input_step, int8_t input_zp);
 void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
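With 21 parameters, everything after kernel_w is passed on the stack under AAPCS64, one 8-byte slot per argument; this is the layout the ldr offsets at the top of the assembly read from:

    // Stack slots for ConvDwInt8Center arguments 9..21 (one 8-byte slot each):
    //   [sp, #0]  out_h_step      [sp, #48] in_zp            [sp, #80] right_shift
    //   [sp, #8]  block_channel   [sp, #56] out_zp           [sp, #88] acc_min
    //   [sp, #16] in_sh_step      [sp, #64] out_multiplier   [sp, #96] acc_max
    //   [sp, #24] in_sw_step      [sp, #72] left_shift
    //   [sp, #32] in_kh_step
    //   [sp, #40] in_kw_step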
@@ -138,75 +138,67 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
 }
 /*conv depthwise int8 end*/
-/*conv depthwise sliding window int8 begin*/
-void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
-int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
-int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
-bool per_channel) {
-int tmp_buffer[C4NUM];
-for (int i = 0; i < C4NUM; i++) {
+/*conv depthwise sliding window perchannel int8 begin*/
+void DepthwiseBorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
+int width, int in_kh_step, int in_kw_step, int kernel_w, int8_t *input_zp,
+int32_t *out_zp, int *out_multiplier, int *left_shift, int *right_shift, int32_t *acc_min,
+int32_t *acc_max) {
+int tmp_buffer[C8NUM];
+for (int i = 0; i < C8NUM; i++) {
 tmp_buffer[i] = 0;
 }
-const int16_t *src_kh = src;
+const int8_t *src_kh = src;
 const int16_t *weight_kh = weight;
 for (int kh = 0; kh < height; kh++) {
-const int16_t *src_kw = src_kh;
+const int8_t *src_kw = src_kh;
 const int16_t *weight_kw = weight_kh;
 for (int kw = 0; kw < width; kw++) {
-for (int c = 0; c < C4NUM; c++) {
-tmp_buffer[c] += src_kw[c] * weight_kw[c];
+for (int c = 0; c < C8NUM; c++) {
+tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c];
 }
 src_kw += in_kw_step;
-weight_kw += C4NUM;
+weight_kw += C8NUM;
 } // kernel_w loop
 src_kh += in_kh_step;
-weight_kh += kernel_w * C4NUM;
+weight_kh += kernel_w * C8NUM;
 } // kernel_h loop
-int32_t left = left_shift[0];
-int32_t right = right_shift[0];
-int32_t multiplier = out_multiplier[0];
-for (int c = 0; c < C4NUM; c++) {
-if (per_channel) {
-left = left_shift[c];
-right = right_shift[c];
-multiplier = out_multiplier[c];
-}
+for (int c = 0; c < C8NUM; c++) {
 tmp_buffer[c] += bias[c];
 tmp_buffer[c] = RoundingDivideByPOT(
-SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
-tmp_buffer[c] += out_zp;
-tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
-tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
+SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
+-right_shift[c]);
+tmp_buffer[c] += out_zp[c];
+tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
+tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
 dst[c] = (tmp_buffer[c]);
 }
 }
-void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top,
+void DepthwiseBorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
 int bottom, int left, int right, const ConvParameter *conv_param,
-const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift,
-bool per_channel) {
+const SlidingWindowParam *sliding, int8_t *in_zp, int32_t *out_zp, int *out_multiplier,
+int *left_shift, int *right_shift, int32_t *acc_min, int32_t *acc_max) {
 int8_t *dst_h = dst + top * sliding->out_h_step_;
 for (int oh = top; oh < bottom; oh++) {
 int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
 int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
 int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
-const int16_t *src_h = src + ih * sliding->in_h_step_;
+const int8_t *src_h = src + ih * sliding->in_h_step_;
 int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
 for (int ow = left; ow < right; ow++) {
 int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
 int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
 int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
-const int16_t *src_w = src_h + iw * sliding->block_channel_;
+const int8_t *src_w = src_h + iw * sliding->block_channel_;
-const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
-const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
+const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
+const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM;
 DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
-sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier,
-left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-per_channel);
+sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp,
+out_multiplier, left_shift, right_shift, acc_min, acc_max);
 dst_kernel += sliding->block_channel_;
 } // width loop
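A quick worked example of the border clipping above, assuming UP_DIV(a, b) = (a + b - 1) / b:

    // kernel_h = 3, dilation_h = 1, stride_h = 1, pad_u = 1, top output row oh = 0:
    //   ih       = 0 * 1 - 1 = -1
    //   start_kh = MSMAX(0, UP_DIV(1, 1)) = 1            // kernel row 0 lies above the image
    //   end_kh   = MSMIN(3, UP_DIV(input_h + 1, 1)) = 3  // for any input_h >= 2
    // Only kernel rows 1..2 are accumulated, and weight_kernel is advanced by
    // (start_kh * kernel_w + start_kw) * C8NUM so it stays aligned with the C8-packed weights.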
@@ -215,52 +207,46 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
 }
 #ifndef ENABLE_ARM64
-void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
+void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
 int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
-int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift,
-int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) {
-int tmp_buffer[C4NUM];
+int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp,
+int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min,
+int32_t *acc_max) {
+int tmp_buffer[C8NUM];
 int8_t *dst_h = dst;
-const int16_t *src_h = src;
+const int8_t *src_h = src;
 for (int oh = 0; oh < height; oh++) {
 int8_t *dst_w = dst_h;
-const int16_t *src_w = src_h;
+const int8_t *src_w = src_h;
 for (int ow = 0; ow < width; ow++) {
-const int16_t *src_kh = src_w;
+const int8_t *src_kh = src_w;
 const int16_t *weight_kh = weight;
-for (int i = 0; i < C4NUM; i++) {
+for (int i = 0; i < C8NUM; i++) {
 tmp_buffer[i] = 0;
 }
 for (int kh = 0; kh < kernel_h; kh++) {
-const int16_t *src_kw = src_kh;
+const int8_t *src_kw = src_kh;
 const int16_t *weight_kw = weight_kh;
 for (int kw = 0; kw < kernel_w; kw++) {
-for (int c = 0; c < C4NUM; c++) {
-tmp_buffer[c] += src_kw[c] * weight_kw[c];
+for (int c = 0; c < C8NUM; c++) {
+tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c];
 }
 src_kw += in_kw_step;
-weight_kw += C4NUM;
+weight_kw += C8NUM;
 } // kernel_w loop
 src_kh += in_kh_step;
-weight_kh += kernel_w * C4NUM;
+weight_kh += kernel_w * C8NUM;
 } // kernel_h loop
 // add bias relu
-int32_t left = left_shift[0];
-int32_t right = right_shift[0];
-int32_t multiplier = out_multiplier[0];
-for (int c = 0; c < C4NUM; c++) {
-if (per_channel) {
-left = left_shift[c];
-right = right_shift[c];
-multiplier = out_multiplier[c];
-}
+for (int c = 0; c < C8NUM; c++) {
 tmp_buffer[c] += bias[c];
 tmp_buffer[c] = RoundingDivideByPOT(
-SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
-tmp_buffer[c] += out_zp;
-tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
-tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
+SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
+-right_shift[c]);
+tmp_buffer[c] += out_zp[c];
+tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
+tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
 dst_w[c] = (tmp_buffer[c]);
 }
 dst_w += block_channel;
@@ -272,69 +258,65 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
 }
 #endif
-void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
-const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
-const int16_t *src = input_data;
+void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
+int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
+const SlidingWindowParam *sliding, int task_id) {
+const int8_t *src = input_data;
 int8_t *dst = output_data;
-bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
-int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
-int *left_shift = conv_param->conv_quant_arg_.left_shift_;
-int *right_shift = conv_param->conv_quant_arg_.right_shift_;
 for (int b = 0; b < conv_param->output_batch_; b++) {
 for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
-const int16_t *src_data = src + oc * C4NUM;
-int8_t *dst_data = dst + oc * C4NUM;
+const int8_t *src_data = src + oc * C8NUM;
+int8_t *dst_data = dst + oc * C8NUM;
 const int16_t *weight = weight_data + oc * sliding->kernel_step_;
-const int32_t *bias = bias_data + oc * C4NUM;
+const int32_t *bias = bias_data + oc * C8NUM;
-if (per_channel) {
-out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM;
-left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM;
-right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM;
-}
+int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM;
+int *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM;
+int *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM;
+int *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM;
+int *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM;
+int8_t *in_zp = input_zp + oc * C8NUM;
+int32_t *out_zp = output_zp + oc * C8NUM;
 DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param,
-sliding, out_multiplier, left_shift, right_shift, per_channel);
+sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
 DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0,
-conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
-per_channel);
+conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
+right_shift, acc_min, acc_max);
 DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_,
-conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel);
+conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min,
+acc_max);
 DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
-conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
-per_channel);
+conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
+right_shift, acc_min, acc_max);
 if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
 int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
 int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
-const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
+const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
 int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
 #ifdef ENABLE_ARM64
 ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
 conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t),
-sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t),
-sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
-sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
-conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
-conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
+sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int8_t),
+sliding->in_sw_step_ * sizeof(int8_t), sliding->in_kh_step_ * sizeof(int8_t),
+sliding->in_kw_step_ * sizeof(int8_t), in_zp, out_zp, out_multiplier, left_shift, right_shift,
+acc_min, acc_max);
 #else
-DepthwiseCenterInt8(
-out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
-conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
-sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier,
-left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
+DepthwiseCenterInt8(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
+sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
+sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
+sliding->in_kh_step_, sliding->in_kw_step_, in_zp, out_zp, out_multiplier, left_shift,
+right_shift, acc_min, acc_max);
 #endif
 }
-} // output C4 loop
+} // output C8 loop
 src += sliding->in_step_;
 dst += sliding->out_step_;
 } // batch loop
-// output nhwc4
+// output nhwc8
 }
-/*conv depthwise sliding window int8 end*/
+/*conv depthwise sliding window perchannel int8 end*/
 /*deconv depthwise int8 begin*/
 void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,
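ConvDwSWInt8 now expects input_zp and output_zp as per-channel buffers padded to C8NUM granularity (it indexes them as input_zp + oc * C8NUM); presumably these are the tmp-quant buffers released by FreeTmpQuant further down. A minimal setup sketch, assuming hypothetical buffer names and that ConvQuantArg exposes input_quant_args_/output_quant_args_ arrays, with a per-tensor zero point simply replicated across channels:

    // Hypothetical preparation of the C8-padded zero-point buffers.
    int c8 = UP_DIV(channel, C8NUM) * C8NUM;
    int8_t *input_zp = (int8_t *)malloc(c8 * sizeof(int8_t));
    int32_t *output_zp = (int32_t *)malloc(c8 * sizeof(int32_t));
    for (int c = 0; c < c8; c++) {
      int i = per_channel ? MSMIN(c, channel - 1) : 0;  // replicate into the padding lanes
      input_zp[c] = conv_param->conv_quant_arg_.input_quant_args_[i].zp_;
      output_zp[c] = conv_param->conv_quant_arg_.output_quant_args_[i].zp_;
    }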
@@ -27,8 +27,9 @@ extern "C" {
 void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
 const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
-void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
-const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
+void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
+int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
+const SlidingWindowParam *sliding, int task_id);
 void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
 const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
@@ -965,6 +965,45 @@ void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
 }
 }
+void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) {
+int c8 = UP_DIV(channel, C8NUM);
+int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
+int ic_remainder_ = channel % C8NUM;
+if (ic_remainder_ != 0) {
+int nhwc8_batch_offset = 0;
+for (int b = 0; b < batch; b++) {
+int batch_offset = b * channel * plane;
+for (int i = 0; i < plane; i++) {
+memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel,
+channel);
+}
+nhwc8_batch_offset += nhwc8_batch_unit_offset;
+}
+} else {
+size_t ori_input_size = batch * plane * channel;
+memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
+}
+}
+void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
+int c8 = UP_DIV(channel, C8NUM);
+int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
+int ic_remainder_ = channel % C8NUM;
+if (ic_remainder_ != 0) {
+for (int b = 0; b < batch; b++) {
+int batch_offset = b * channel * plane;
+int nhwc8_batch_offset = b * nhwc8_batch_unit_offset;
+for (int i = 0; i < plane; i++) {
+memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM,
+channel);
+}
+}
+} else {
+size_t ori_input_size = batch * plane * channel;
+memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
+}
+}
 void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
 int nhwc4_batch_offset = 0;
 int c4 = UP_DIV(channel, C4NUM);
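Note that when channel % C8NUM != 0, PackNHWCToNHWC8Int8 copies channel bytes per pixel at a stride of c8 * C8NUM but never writes the padding lanes, so the destination should be zeroed first. A small usage sketch:

    // batch = 1, plane = H * W = 4, channel = 5  ->  c8 = 1, padded channel = 8
    int8_t src[1 * 4 * 5];
    int8_t dst[1 * 4 * 8];
    memset(dst, 0, sizeof(dst));             // padding lanes are left untouched by the pack
    PackNHWCToNHWC8Int8(src, dst, 1, 4, 5);  // copies 5 bytes per pixel at stride 8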
@@ -1270,6 +1309,25 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
 void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
 ConvQuantArg *quant_qrg) {
 int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
+for (int c = 0; c < channel; c++) {
+if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
+weight_zp = quant_qrg->filter_quant_args_[c].zp_;
+}
+int c8_block_num = c / C8NUM;
+int c8_block_rem = c % C8NUM;
+const int8_t *src_c = origin_weight + c * plane;
+int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM;
+for (int k = 0; k < plane; k++) {
+const int8_t *src_kernel = src_c + k;
+int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem;
+*dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
+}
+}
+}
+void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
+ConvQuantArg *quant_qrg) {
+int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
 for (int c = 0; c < channel; c++) {
 if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
 weight_zp = quant_qrg->filter_quant_args_[c].zp_;
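The packed weight layout produced here is (o/8, h*w, 8): channel c of kernel tap k lands at index (c / C8NUM) * plane * C8NUM + k * C8NUM + c % C8NUM, already widened to int16 with the filter zero point subtracted. For example:

    // 3x3 kernel (plane = 9), channel c = 10, tap k = 4, C8NUM = 8:
    int idx = (10 / 8) * 9 * 8 + 4 * 8 + 10 % 8;  // = 72 + 32 + 2 = 106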
| @@ -96,6 +96,10 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c | |||||
| void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | ||||
| void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||||
| void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||||
| void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | ||||
| void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | ||||
| @@ -114,6 +118,9 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter | |||||
| void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | ||||
| ConvQuantArg *quant_qrg); | ConvQuantArg *quant_qrg); | ||||
| void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | |||||
| ConvQuantArg *quant_qrg); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -177,8 +177,17 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | ||||
| auto kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| kernel::LiteKernel *kernel; | |||||
| auto act_quant_size = | |||||
| MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size()); | |||||
| if (act_quant_size == 1) { // per tensor | |||||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } else { // per channel | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| return nullptr; | return nullptr; | ||||
@@ -37,6 +37,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
     free(packed_weight_);
     packed_weight_ = nullptr;
   }
+  FreeTmpQuant();
   FreeQuantParam();
 }
@@ -45,8 +46,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
   // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
   auto weight_tensor = in_tensors_[kWeightIndex];
   auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
-  int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
-  int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
+  int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
+  int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
   packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
   if (packed_weight_ == nullptr) {
     MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -55,35 +56,36 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
   PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                           weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
-  bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
+  bias_data_ = reinterpret_cast<int32_t *>(malloc(C8NUM * OC8 * sizeof(int32_t)));
   if (bias_data_ == nullptr) {
     MS_LOG(ERROR) << "Malloc buffer failed.";
     return RET_ERROR;
   }
-  memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t));
+  memset(bias_data_, 0, C8NUM * OC8 * sizeof(int32_t));
   if (in_tensors_.size() == kInputSize2) {
     auto bias_tensor = in_tensors_.at(kBiasIndex);
     auto ori_bias = reinterpret_cast<int32_t *>(bias_tensor->MutableData());
     memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t));
   }
-  conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
+  conv_param_->thread_num_ = MSMIN(thread_count_, OC8);
   return RET_OK;
 }

 int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
-  int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
-                        UP_DIV(conv_param_->input_channel_, 4);
-  packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
-  if (packed_input_ == nullptr) {
-    MS_LOG(ERROR) << "Malloc buffer failed.";
-    return RET_ERROR;
-  }
-  if (conv_param_->input_channel_ % C4NUM != 0) {
+  if (conv_param_->input_channel_ % C8NUM != 0) {
     need_align_ = true;
-    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM *
-                           UP_DIV(conv_param_->output_channel_, C4NUM);
+    int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM *
+                          UP_DIV(conv_param_->input_channel_, C8NUM);
+    packed_input_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int8_t)));
+    if (packed_input_ == nullptr) {
+      MS_LOG(ERROR) << "Malloc buffer failed.";
+      return RET_ERROR;
+    }
+    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM *
+                           UP_DIV(conv_param_->output_channel_, C8NUM);
     packed_output_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_output_size * sizeof(int8_t)));
     if (packed_output_ == nullptr) {
       MS_LOG(ERROR) << "Malloc buffer failed.";
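Buffer sizing follows the same rounding as the weights: the channel dimension is split into UP_DIV(channel, C8NUM) blocks of eight lanes each. A worked example with illustrative shapes:

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define C8NUM 8

int main() {
  // Illustrative NHWC int8 input of 1x16x16x21.
  int batch = 1, h = 16, w = 16, channel = 21;
  int c8_blocks = UP_DIV(channel, C8NUM);                    // (21 + 7) / 8 = 3 blocks
  int pack_input_size = batch * h * w * C8NUM * c8_blocks;   // 1*16*16*8*3 = 6144 bytes
  printf("blocks=%d size=%d\n", c8_blocks, pack_input_size);
  // 21 % 8 != 0, so need_align_ is set and the padded buffers above are allocated.
  return 0;
}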
@@ -93,6 +95,136 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
   return RET_OK;
 }

+void ConvolutionDepthwiseSWInt8CPUKernel::FreeTmpQuant() {
+  if (input_scale_ != nullptr) {
+    free(input_scale_);
+    input_scale_ = nullptr;
+  }
+  if (input_zp_ != nullptr) {
+    free(input_zp_);
+    input_zp_ = nullptr;
+  }
+  if (weight_scale_ != nullptr) {
+    free(weight_scale_);
+    weight_scale_ = nullptr;
+  }
+  if (output_scale_ != nullptr) {
+    free(output_scale_);
+    output_scale_ = nullptr;
+  }
+  if (output_zp_ != nullptr) {
+    free(output_zp_);
+    output_zp_ = nullptr;
+  }
+}
+
+int ConvolutionDepthwiseSWInt8CPUKernel::ReinitFreeBefore() {
+  FreeTmpQuant();
+  if (conv_quant_arg_->real_multiplier_ != nullptr) {
+    free(conv_quant_arg_->real_multiplier_);
+    conv_quant_arg_->real_multiplier_ = nullptr;
+  }
+  if (conv_quant_arg_->left_shift_ != nullptr) {
+    free(conv_quant_arg_->left_shift_);
+    conv_quant_arg_->left_shift_ = nullptr;
+  }
+  if (conv_quant_arg_->right_shift_ != nullptr) {
+    free(conv_quant_arg_->right_shift_);
+    conv_quant_arg_->right_shift_ = nullptr;
+  }
+  if (conv_quant_arg_->quant_multiplier_ != nullptr) {
+    free(conv_quant_arg_->quant_multiplier_);
+    conv_quant_arg_->quant_multiplier_ = nullptr;
+  }
+  if (conv_quant_arg_->out_act_min_ != nullptr) {
+    free(conv_quant_arg_->out_act_min_);
+    conv_quant_arg_->out_act_min_ = nullptr;
+  }
+  if (conv_quant_arg_->out_act_max_ != nullptr) {
+    free(conv_quant_arg_->out_act_max_);
+    conv_quant_arg_->out_act_max_ = nullptr;
+  }
+  return RET_OK;
+}
+
+int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
+  ReinitFreeBefore();  // remalloc quant param buffer
+
+  auto input_tensor = in_tensors_.at(kInputIndex);
+  auto channel = conv_param_->input_channel_;
+  input_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
+  input_zp_ = reinterpret_cast<int8_t *>(malloc(channel * sizeof(int8_t)));
+  if (input_tensor->GetQuantParams().size() == kPerTensor) {
+    for (int i = 0; i < channel; i++) {
+      auto input_quant_arg = input_tensor->GetQuantParams().front();
+      input_zp_[i] = input_quant_arg.zeroPoint;
+      input_scale_[i] = input_quant_arg.scale;
+    }
+  } else {
+    for (int i = 0; i < channel; i++) {
+      auto input_quant_arg = input_tensor->GetQuantParams()[i];
+      input_zp_[i] = input_quant_arg.zeroPoint;
+      input_scale_[i] = input_quant_arg.scale;
+    }
+  }
+
+  auto output_tensor = out_tensors_.at(kOutputIndex);
+  output_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
+  output_zp_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+  if (output_tensor->GetQuantParams().size() == kPerTensor) {
+    for (int i = 0; i < channel; i++) {
+      auto output_quant_arg = output_tensor->GetQuantParams().front();
+      output_zp_[i] = output_quant_arg.zeroPoint;
+      output_scale_[i] = output_quant_arg.scale;
+    }
+  } else {
+    for (int i = 0; i < channel; i++) {
+      auto output_quant_arg = output_tensor->GetQuantParams()[i];
+      output_zp_[i] = output_quant_arg.zeroPoint;
+      output_scale_[i] = output_quant_arg.scale;
+    }
+  }
+
+  conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(channel * sizeof(double)));
+  conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+  conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+  conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+  conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+  conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
+
+  weight_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
+  auto weight_tensor = in_tensors_.at(kWeightIndex);
+  if (weight_tensor->GetQuantParams().size() == kPerTensor) {
+    for (int i = 0; i < channel; i++) {
+      auto weight_quant_arg = weight_tensor->GetQuantParams().front();
+      weight_scale_[i] = weight_quant_arg.scale;
+    }
+  } else {
+    for (int i = 0; i < channel; i++) {
+      auto weight_quant_arg = weight_tensor->GetQuantParams()[i];
+      weight_scale_[i] = weight_quant_arg.scale;
+    }
+  }
+
+  for (int i = 0; i < channel; ++i) {
+    const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
+    double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
+    conv_quant_arg_->real_multiplier_[i] = real_multiplier;
+    QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
+                           &conv_quant_arg_->right_shift_[i]);
+  }
+
+  // now only consider per tensor for output
+  bool relu = conv_param_->act_type_ == ActType_Relu;
+  bool relu6 = conv_param_->act_type_ == ActType_Relu6;
+  for (int i = 0; i < channel; ++i) {
+    CalculateActivationRangeQuantized(relu, relu6, output_zp_[i], output_scale_[i],
+                                      &conv_param_->conv_quant_arg_.out_act_min_[i],
+                                      &conv_param_->conv_quant_arg_.out_act_max_[i]);
+  }
+  return RET_OK;
+}
+
 int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
   sliding = new (std::nothrow) SlidingWindowParam;
   if (sliding == nullptr) {
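For each channel c, ReinitQuantParam stores real_multiplier = (input_scale[c] * weight_scale[c]) / output_scale[c], and QuantizeRoundParameter folds that double into an int32 multiplier plus left/right shifts for the kernel's fixed-point requantization. A sketch of the usual gemmlowp-style decomposition; this is an assumption about what QuantizeRoundParameter computes, and the repository's rounding may differ:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Split a real multiplier into (q31 multiplier, shift) with real = m/2^31 * 2^shift.
void QuantizeMultiplierSketch(double real_multiplier, int32_t *quantized_multiplier, int *shift) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  int exponent;
  double q = std::frexp(real_multiplier, &exponent);   // real = q * 2^exponent, q in [0.5, 1)
  auto q31 = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (q31 == (1ll << 31)) {                            // q rounded up to exactly 1.0
    q31 /= 2;
    ++exponent;
  }
  *quantized_multiplier = static_cast<int32_t>(q31);
  *shift = exponent;  // positive -> left shift, negative -> right shift
}

int main() {
  int32_t m;
  int s;
  // e.g. input_scale = 0.5, weight_scale = 0.25, output_scale = 0.1 -> M = 1.25
  QuantizeMultiplierSketch(0.5 * 0.25 / 0.1, &m, &s);
  printf("multiplier=%d shift=%d\n", m, s);  // 0.625 * 2^31, shift = 1
  return 0;
}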
@@ -107,13 +239,19 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Init() {

 int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
   ConvolutionBaseCPUKernel::Init();
-  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding, conv_param_, C8NUM);

   auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Set quant param failed.";
     return ret;
   }
+  ret = ReinitQuantParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "reinit quant param failed.";
+    return ret;
+  }
   ret = InitWeightBias();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
@@ -123,8 +261,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
 }

 int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) {
-  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
-               sliding, task_id);
+  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), input_zp_,
+               output_zp_, conv_param_, sliding, task_id);
   return RET_OK;
 }
@@ -157,7 +295,12 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
   auto input_tensor = in_tensors_.at(kInputIndex);
   auto input_addr = reinterpret_cast<int8_t *>(input_tensor->MutableData());
-  PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_);
+  if (need_align_) {
+    PackNHWCToNHWC8Int8(input_addr, packed_input_, conv_param_->input_batch_,
+                        conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
+  } else {
+    packed_input_ = input_addr;
+  }

   auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
   if (!need_align_) {
@@ -171,11 +314,11 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
   }

   if (need_align_) {
-    PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
+    PackNHWC8ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
                         conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
+    context_->allocator->Free(packed_input_);
     context_->allocator->Free(packed_output_);
   }
-  context_->allocator->Free(packed_input_);
   return RET_OK;
 }
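Moving Free(packed_input_) inside the need_align_ branch matches the new ownership rule in Run(): when the channel count is already a multiple of 8, packed_input_ merely aliases the input tensor's own buffer, and freeing it unconditionally would release memory the kernel does not own. An illustrative sketch of that rule, not repository code:

#include <cstdint>
#include <cstdlib>

// Ownership of packed_input_ depends on whether an aligned copy was made.
void RunOnce(int8_t *input_addr, bool need_align) {
  int8_t *packed_input = nullptr;
  if (need_align) {
    packed_input = static_cast<int8_t *>(malloc(64));  // kernel-owned padded copy
    // ... pack, convolve, unpack ...
    free(packed_input);                                // release only what we allocated
  } else {
    packed_input = input_addr;                         // borrowed view; must NOT be freed
    // ... convolve directly on the tensor's buffer ...
  }
  (void)packed_input;
}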
@@ -40,11 +40,21 @@ class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
   int Execute(int task_id);

  private:
+  int ReinitQuantParam();
+  int ReinitFreeBefore();
+  void FreeTmpQuant();
   SlidingWindowParam *sliding = nullptr;
   int16_t *packed_weight_ = nullptr;
-  int16_t *packed_input_ = nullptr;
+  int8_t *packed_input_ = nullptr;
   int8_t *packed_output_ = nullptr;
   bool need_align_ = false;
+  int8_t *input_zp_ = nullptr;
+  float *input_scale_ = nullptr;
+  float *weight_scale_ = nullptr;
+  int32_t *output_zp_ = nullptr;
+  float *output_scale_ = nullptr;
 };
 }  // namespace mindspore::kernel
@@ -52,8 +52,8 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
     MS_LOG(ERROR) << "Malloc buffer failed.";
     return RET_ERROR;
   }
-  PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
-                          weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
+  PackDeconvDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
+                                weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
   bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
   if (bias_data_ == nullptr) {