| @@ -0,0 +1,116 @@ | |||
| #ifdef __arm__ | |||
| #ifndef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDw3x3BorderPixelInt8 | |||
| #ifndef __APPLE__ | |||
| .type ConvDw3x3BorderPixelInt8, %function | |||
| #endif | |||
| // void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, | |||
| // size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, | |||
| // size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) { | |||
| // | |||
| // For every group of 8 channels: starts from 8 int32 bias values, accumulates (src - in_zp) * weight over | |||
| // height x width taps, requantizes, clamps to [acc_min, acc_max] and stores 8 int8 results to dst. | |||
| // | |||
| // Only the first four arguments arrive in registers: r0: dst, r1: src, r2: weight, r3: bias. | |||
| // All remaining arguments are on the caller's stack. The prologue pushes 104 bytes | |||
| // (10 core registers = 40 bytes, q4-q7 = 64 bytes), so after it they are located at: | |||
| // [sp, #104] height, [sp, #108] width, [sp, #112] in_kh_step, [sp, #116] in_kw_step, | |||
| // [sp, #120] channel, [sp, #124] in_zp, [sp, #128] out_zp, [sp, #132] out_multiplier, | |||
| // [sp, #136] left_shift, [sp, #140] right_shift, [sp, #144] acc_min, [sp, #148] acc_max | |||
| // | |||
| // Register use inside the function: | |||
| // r4: remaining rows, r5: remaining columns (also scratch for the weight row stride), | |||
| // r6: in_kh_step, r7: in_kw_step, r8: remaining channels, | |||
| // r9/r10: src/weight pointer at the start of the current row, | |||
| // r11/r12: src/weight pointer at the current tap, lr: channel * 2 (weight stride per tap, in bytes) | |||
| // d18: in_zp, q15: out_zp, q14: out_multiplier, q13: left_shift, q12: right_shift, q11: acc_min, q10: acc_max | |||
| ConvDw3x3BorderPixelInt8: | |||
| // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr" | |||
| // according to https://stackoverflow.com/questions/53625807 | |||
| // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway | |||
| // clang's rule seems more simple, though there are no subroutine calls here | |||
| // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf | |||
| push {r4-r8, r9-r12, lr} | |||
| vpush {q4-q7} | |||
| // sp is deliberately left pointing at the bottom of the save area. Raising it above the saved registers | |||
| // would place them below sp, where a signal or interrupt frame may overwrite them (AAPCS32 has no red zone). | |||
| // Stack arguments are therefore addressed with the 104-byte save area folded into the offset. | |||
| ldr r6, [sp, #112] // in_kh_step | |||
| ldr r7, [sp, #116] // in_kw_step | |||
| ldr r8, [sp, #120] // channel | |||
| ldrb r10, [sp, #124] // in_zp, only the low byte is used | |||
| vdup.8 d18, r10 | |||
| ldr r10, [sp, #128] // out_zp | |||
| vdup.32 q15, r10 | |||
| ldr r10, [sp, #132] // out_multiplier | |||
| vdup.32 q14, r10 | |||
| ldr r10, [sp, #136] // left_shift | |||
| vdup.32 q13, r10 | |||
| ldr r10, [sp, #140] // right_shift | |||
| vdup.32 q12, r10 | |||
| ldr r10, [sp, #144] // acc_min | |||
| vdup.32 q11, r10 | |||
| ldr r10, [sp, #148] // acc_max | |||
| vdup.32 q10, r10 | |||
| mov r4, #2 | |||
| mul lr, r8, r4 // lr = channel * sizeof(int16_t) | |||
| LoopC: | |||
| mov r9, r1 | |||
| mov r10, r2 | |||
| ldr r4, [sp, #104] // height | |||
| vld1.32 {q3}, [r3]! // accumulators start from bias, channels 0-3 | |||
| vld1.32 {q4}, [r3]! // channels 4-7 | |||
| LoopH: | |||
| mov r11, r9 | |||
| mov r12, r10 | |||
| ldr r5, [sp, #108] // width | |||
| LoopW: | |||
| vld1.8 {d0}, [r11], r7 // 8 int8 inputs, then step src by in_kw_step | |||
| vld1.16 {d2, d3}, [r12], lr // weight | |||
| vsubl.s8 q2, d0, d18 // -zp | |||
| vmlal.s16 q3, d4, d2 | |||
| vmlal.s16 q4, d5, d3 | |||
| subs r5, r5, #1 | |||
| bne LoopW | |||
| subs r4, r4, #1 // sets the flags consumed by "bne LoopH"; nothing in between updates them | |||
| add r9, r9, r6 | |||
| mov r11, #3 | |||
| mul r5, lr, r11 // weight row stride = 3 taps * channel * 2 bytes | |||
| add r10, r10, r5 | |||
| bne LoopH | |||
| // Requantize channels 0-3. | |||
| // NOTE(review): the AArch64 implementation of this routine uses a saturating left shift (sqshl) here; | |||
| // confirm whether the non-saturating vshl is intended on AArch32. | |||
| vshl.s32 q3, q3, q13 | |||
| vqrdmulh.s32 q3, q3, q14 | |||
| vand q5, q3, q12 // sign bit set only where value and shift count are both negative | |||
| vshr.s32 q5, q5, #31 // -1 in those lanes, 0 elsewhere | |||
| vqadd.s32 q3, q3, q5 | |||
| vrshl.s32 q3, q3, q12 // rounding shift; presumably right_shift is passed as a negative count - verify against caller | |||
| vadd.i32 q3, q3, q15 | |||
| vmax.s32 q3, q3, q11 | |||
| vmin.s32 q3, q3, q10 | |||
| vqmovn.s32 d14, q3 | |||
| // Requantize channels 4-7. | |||
| vshl.s32 q4, q4, q13 | |||
| vqrdmulh.s32 q4, q4, q14 | |||
| vand q6, q4, q12 | |||
| vshr.s32 q6, q6, #31 | |||
| vqadd.s32 q4, q4, q6 | |||
| vrshl.s32 q4, q4, q12 | |||
| vadd.i32 q4, q4, q15 | |||
| vmax.s32 q4, q4, q11 | |||
| vmin.s32 q4, q4, q10 | |||
| vqmovn.s32 d15, q4 | |||
| vqmovn.s16 d16, q7 // q7 = d14:d15, narrow the 8 int16 lanes to int8 | |||
| vst1.8 {d16}, [r0]! | |||
| add r1, r1, #8 // next 8 int8 input channels | |||
| add r2, r2, #16 // next 8 int16 weight channels | |||
| sub r8, r8, #8 | |||
| cmp r8, #8 | |||
| bge LoopC | |||
| vpop {q4-q7} | |||
| pop {r4-r8, r9-r12, pc} | |||
| #endif | |||
| #endif | |||
| @@ -41,66 +41,128 @@ ConvDw3x3BorderPixelInt8: | |||
| mul x14, x13, x9 // x8 * 3 * 2 | |||
| LoopC: | |||
| ld1 {v23.4s}, [x3], #16 | |||
| ld1 {v24.4s}, [x3], #16 | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| mov x17, x4 // height | |||
| ld1 {v5.4s}, [x3], #16 | |||
| mov v3.16b, v5.16b | |||
| ld1 {v6.4s}, [x3], #16 | |||
| mov v4.16b, v6.16b | |||
| LoopH: | |||
| mov x11, x9 | |||
| mov x12, x10 | |||
| mov x18, x5 // width | |||
| LoopW: | |||
| ld1 {v0.8b}, [x11], x7 | |||
| ssubl v1.8h, v0.8b, v25.8b | |||
| ld1 {v2.8h}, [x12], x13 // weight | |||
| smlal v3.4s, v1.4h, v2.4h | |||
| smlal2 v4.4s, v1.8h, v2.8h | |||
| subs x18, x18, #1 | |||
| bne LoopW | |||
| subs x17, x17, #1 | |||
| add x9, x9, x6 | |||
| add x10, x10, x14 | |||
| bne LoopH | |||
| sqshl v3.4s, v3.4s, v28.4s | |||
| sqshl v4.4s, v4.4s, v28.4s | |||
| sqrdmulh v3.4s, v3.4s, v27.4s | |||
| sqrdmulh v4.4s, v4.4s, v27.4s | |||
| and v12.16b, v29.16b, v3.16b | |||
| sshr v12.4s, v12.4s, #31 | |||
| sqadd v3.4s, v3.4s, v12.4s | |||
| srshl v3.4s, v3.4s, v29.4s | |||
| and v11.16b, v29.16b, v4.16b | |||
| sshr v11.4s, v11.4s, #31 | |||
| sqadd v4.4s, v4.4s, v11.4s | |||
| srshl v4.4s, v4.4s, v29.4s | |||
| add v3.4s, v3.4s, v26.4s | |||
| add v4.4s, v4.4s, v26.4s | |||
| smax v3.4s, v3.4s, v30.4s | |||
| smax v4.4s, v4.4s, v30.4s | |||
| smin v3.4s, v3.4s, v31.4s | |||
| smin v4.4s, v4.4s, v31.4s | |||
| sqxtn v3.4h, v3.4s | |||
| sqxtn v4.4h, v4.4s | |||
| sqxtn v3.8b, v3.8h | |||
| sqxtn v4.8b, v4.8h | |||
| st1 {v3.s}[0], [x0], #4 | |||
| st1 {v4.s}[0], [x0], #4 | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| sub x8, x8, #8 | |||
| cmp x8, #8 | |||
| bge LoopC | |||
| cmp x4, #2 | |||
| blt LoopHW | |||
| LoopH2W2: | |||
| cmp x5, #2 | |||
| blt LoopHW | |||
| ld1 {v0.8b}, [x9], x7 | |||
| ssubl v0.8h, v0.8b, v25.8b | |||
| add x11, x1, x6 | |||
| ld1 {v4.8h}, [x10], x13 // weight | |||
| smlal v23.4s, v0.4h, v4.4h | |||
| smlal2 v24.4s, v0.8h, v4.8h | |||
| add x12, x2, x14 | |||
| ld1 {v1.8b}, [x9], x7 | |||
| ssubl v1.8h, v1.8b, v25.8b | |||
| ld1 {v5.8h}, [x10], x13 | |||
| smlal v23.4s, v1.4h, v5.4h | |||
| smlal2 v24.4s, v1.8h, v5.8h | |||
| add x15, x11, x6 | |||
| ld1 {v2.8b}, [x11], x7 | |||
| ssubl v2.8h, v2.8b, v25.8b | |||
| add x16, x12, x14 | |||
| ld1 {v6.8h}, [x12], x13 | |||
| smlal v23.4s, v2.4h, v6.4h | |||
| smlal2 v24.4s, v2.8h, v6.8h | |||
| ld1 {v3.8b}, [x11], x7 | |||
| ssubl v3.8h, v3.8b, v25.8b | |||
| ld1 {v7.8h}, [x12], x13 | |||
| smlal v23.4s, v3.4h, v7.4h | |||
| smlal2 v24.4s, v3.8h, v7.8h | |||
| cmp x5, #3 | |||
| beq LoopH2W3 | |||
| cmp x4, #3 | |||
| beq LoopH3W2 | |||
| b Post | |||
| LoopH2W3: | |||
| ld1 {v16.8b}, [x9], x7 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v17.8h}, [x10], x13 | |||
| smlal v23.4s, v16.4h, v17.4h | |||
| smlal2 v24.4s, v16.8h, v17.8h | |||
| ld1 {v18.8b}, [x11], x7 | |||
| ssubl v18.8h, v18.8b, v25.8b | |||
| ld1 {v19.8h}, [x12], x13 | |||
| smlal v23.4s, v18.4h, v19.4h | |||
| smlal2 v24.4s, v18.8h, v19.8h | |||
| b Post | |||
| LoopH3W2: | |||
| ld1 {v16.8b}, [x15], x7 | |||
| ssubl v16.8h, v16.8b, v25.8b | |||
| ld1 {v17.8h}, [x16], x13 | |||
| smlal v23.4s, v16.4h, v17.4h | |||
| smlal2 v24.4s, v16.8h, v17.8h | |||
| ld1 {v18.8b}, [x15], x7 | |||
| ssubl v18.8h, v18.8b, v25.8b | |||
| ld1 {v19.8h}, [x16], x13 | |||
| smlal v23.4s, v18.4h, v19.4h | |||
| smlal2 v24.4s, v18.8h, v19.8h | |||
| b Post | |||
| LoopHW: | |||
| mov x9, x1 | |||
| mov x10, x2 | |||
| mov x17, x4 // height | |||
| LoopH: | |||
| mov x11, x9 | |||
| mov x12, x10 | |||
| mov x18, x5 // width | |||
| LoopW: | |||
| ld1 {v0.8b}, [x11], x7 | |||
| ssubl v1.8h, v0.8b, v25.8b | |||
| ld1 {v2.8h}, [x12], x13 // weight | |||
| smlal v23.4s, v1.4h, v2.4h | |||
| smlal2 v24.4s, v1.8h, v2.8h | |||
| subs x18, x18, #1 | |||
| bne LoopW | |||
| subs x17, x17, #1 | |||
| add x9, x9, x6 | |||
| add x10, x10, x14 | |||
| bne LoopH | |||
| Post: | |||
| sqshl v23.4s, v23.4s, v28.4s | |||
| sqshl v24.4s, v24.4s, v28.4s | |||
| sqrdmulh v23.4s, v23.4s, v27.4s | |||
| sqrdmulh v24.4s, v24.4s, v27.4s | |||
| and v12.16b, v29.16b, v23.16b | |||
| sshr v12.4s, v12.4s, #31 | |||
| sqadd v23.4s, v23.4s, v12.4s | |||
| srshl v23.4s, v23.4s, v29.4s | |||
| and v11.16b, v29.16b, v24.16b | |||
| sshr v11.4s, v11.4s, #31 | |||
| sqadd v24.4s, v24.4s, v11.4s | |||
| srshl v24.4s, v24.4s, v29.4s | |||
| add v23.4s, v23.4s, v26.4s | |||
| add v24.4s, v24.4s, v26.4s | |||
| smax v23.4s, v23.4s, v30.4s | |||
| smax v24.4s, v24.4s, v30.4s | |||
| smin v23.4s, v23.4s, v31.4s | |||
| smin v24.4s, v24.4s, v31.4s | |||
| sqxtn v23.4h, v23.4s | |||
| sqxtn v24.4h, v24.4s | |||
| sqxtn v23.8b, v23.8h | |||
| sqxtn v24.8b, v24.8h | |||
| st1 {v23.s}[0], [x0], #4 | |||
| st1 {v24.s}[0], [x0], #4 | |||
| add x1, x1, #8 | |||
| add x2, x2, #16 | |||
| sub x8, x8, #8 | |||
| cmp x8, #8 | |||
| bge LoopC | |||
| ret | |||
| #endif | |||
| @@ -47,6 +47,10 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con | |||
| size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, | |||
| int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *acc_min, int32_t *acc_max); | |||
| // Assembly kernel for one border output pixel of the 3x3 int8 depthwise convolution. | |||
| // Accumulates (src - in_zp) * weight over height x width taps on top of bias, requantizes with | |||
| // out_multiplier / left_shift / right_shift, adds out_zp, clamps to [acc_min, acc_max] and writes int8 to dst. | |||
| // Every scalar argument is declared as size_t. The AArch32 implementation reads only the low byte of in_zp | |||
| // and walks the channels 8 at a time, always processing at least one group of 8. | |||
| // in_kh_step / in_kw_step are added to the src pointer per row / per tap (byte offsets, since src is int8). | |||
| void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, | |||
| size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, | |||
| size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift, | |||
| size_t acc_min, size_t acc_max); | |||
| #endif | |||
| #ifdef ENABLE_ARM32 | |||
| @@ -67,10 +71,6 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei | |||
| void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); | |||
| void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, | |||
| size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, | |||
| size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift, | |||
| size_t acc_min, size_t acc_max); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -140,11 +140,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||
| /*conv depthwise 3x3 int8 begin*/ | |||
| bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) { | |||
| bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && | |||
| (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && | |||
| (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && | |||
| conv_param->stride_h_ == conv_param->stride_w_ && | |||
| (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && | |||
| bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 && | |||
| conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && | |||
| (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ && | |||
| conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0); | |||
| return use_3x3; | |||
| @@ -303,7 +300,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data | |||
| } | |||
| } | |||
| #ifndef ENABLE_ARM64 | |||
| #ifndef ENABLE_ARM | |||
| void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, | |||
| int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, | |||
| int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) { | |||
| @@ -172,7 +172,21 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> | |||
| auto act_quant_size = | |||
| MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size()); | |||
| if (act_quant_size == 1) { // per tensor | |||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter); | |||
| auto channel = inputs[kWeightIndex]->shape()[0]; | |||
| auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | |||
| if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) { | |||
| #ifdef ENABLE_ARM64 | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| #else | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| #endif | |||
| } else { | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } | |||
| } else { // per channel | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||