| @@ -0,0 +1,110 @@ | |||||
#ifdef __arm__
#ifndef __aarch64__

.text
.align 5
.global ConvDwInt8PostAlign4
#ifndef __APPLE__
.type ConvDwInt8PostAlign4, %function
#endif

// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
//                           int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
//
// Requantizes int32 accumulators from `buffer` into int8 values at `dst`, using one set of
// quantization parameters for every element. Per element:
//   acc = buffer[i] << left_shift                   (vshl, per-lane register shift)
//   acc = vqrdmulh(acc, out_multiplier)             (saturating rounding doubling high multiply)
//   acc = sat(acc + ((acc & right_shift) >> 31))    (adds -1 when acc and right_shift are both negative)
//   acc = vrshl(acc, right_shift)                   (rounding shift; a negative count shifts right)
//   acc = min(max(acc + output_zp, acc_min), acc_max)
//   dst[i] = saturate_to_int8(acc)
// NOTE(review): vrshl only shifts right for a negative count, so right_shift is presumably passed
// as a non-positive value - confirm against the caller.
//
// Elements are consumed 8 at a time, then 4 at a time. A remainder of fewer than 4 elements is
// not processed here; the C caller passes num_pixels rounded down to a multiple of 4.
//
// Arguments (AAPCS32):
//   r0: dst, r1: buffer, r2: num_pixels, r3: output_zp
//   stack: out_multiplier, left_shift, right_shift, acc_min, acc_max (loaded into r4-r8)
// Register roles:
//   r10: running dst pointer
//   q15: output_zp, q14: out_multiplier, q13: left_shift, q12: right_shift, q11: acc_min, q10: acc_max
//   q0-q2, q4: scratch

ConvDwInt8PostAlign4:
    // r4-r8, r10 and q4-q7 are saved before use; r4-r11 and q4-q7 are callee-saved under AAPCS32
    // (https://static.docs.arm.com/ihi0042/i/aapcs32.pdf).
    // This is a leaf function, so lr is left untouched and we return with "bx lr".
    push {r4-r8, r10}               // 6 * 4  = 24 bytes
    vpush {q4-q7}                   // 4 * 16 = 64 bytes -> first stack argument is at [sp, #88]

    // sp is deliberately NOT moved back above the save area: anything below sp may be overwritten
    // asynchronously (signal handler, interrupt), which would corrupt the saved registers.
    vdup.32 q15, r3                 // output_zp
    ldr r4, [sp, #88]               // out_multiplier
    vdup.32 q14, r4
    ldr r5, [sp, #92]               // left_shift
    vdup.32 q13, r5
    ldr r6, [sp, #96]               // right_shift
    vdup.32 q12, r6
    ldr r7, [sp, #100]              // acc_min
    vdup.32 q11, r7
    ldr r8, [sp, #104]              // acc_max
    vdup.32 q10, r8

    mov r10, r0

  LoopDepth8:
    cmp r2, #8
    blt LoopDepth4                  // fewer than 8 left: fall back to the 4-wide tail

    // elements 0-3
    vld1.32 {q0}, [r1]!
    vshl.s32 q0, q0, q13
    vqrdmulh.s32 q0, q0, q14
    vand q4, q0, q12
    vshr.s32 q4, q4, #31            // q4 = -1 where acc and right_shift are both negative, else 0
    vqadd.s32 q0, q0, q4
    vrshl.s32 q0, q0, q12
    vadd.i32 q0, q0, q15
    vmax.s32 q0, q0, q11
    vmin.s32 q0, q0, q10
    vqmovn.s32 d4, q0               // int32 x4 -> int16 x4 (low half of q2)

    // elements 4-7
    vld1.32 {q1}, [r1]!
    vshl.s32 q1, q1, q13
    vqrdmulh.s32 q1, q1, q14
    vand q4, q1, q12
    vshr.s32 q4, q4, #31
    vqadd.s32 q1, q1, q4
    vrshl.s32 q1, q1, q12
    vadd.i32 q1, q1, q15
    vmax.s32 q1, q1, q11
    vmin.s32 q1, q1, q10
    vqmovn.s32 d5, q1               // int32 x4 -> int16 x4 (high half of q2)

    vqmovn.s16 d4, q2               // int16 x8 -> int8 x8
    vst1.8 {d4}, [r10]!

    sub r2, r2, #8
    b LoopDepth8

  LoopDepth4:
    cmp r2, #4
    blt End

    vld1.32 {q0}, [r1]!
    vshl.s32 q0, q0, q13
    vqrdmulh.s32 q0, q0, q14
    vand q4, q0, q12
    vshr.s32 q4, q4, #31
    vqadd.s32 q0, q0, q4
    vrshl.s32 q0, q0, q12
    vadd.i32 q0, q0, q15
    vmax.s32 q0, q0, q11
    vmin.s32 q0, q0, q10

    vqmovn.s32 d0, q0               // int32 x4 -> int16 x4 in d0
    vqmovn.s16 d0, q0               // only the low 4 bytes of d0 are meaningful (d1 is stale)
    vst1.8 {d0[0]}, [r10]!          // store exactly 4 bytes, one lane at a time
    vst1.8 {d0[1]}, [r10]!
    vst1.8 {d0[2]}, [r10]!
    vst1.8 {d0[3]}, [r10]!

    sub r2, r2, #4
    b LoopDepth4

  End:
    vpop {q4-q7}
    pop {r4-r8, r10}
    bx lr

#endif
#endif
| @@ -0,0 +1,134 @@ | |||||
#ifdef __arm__
#ifndef __aarch64__

.text
.align 5
.global ConvDwInt8Row
#ifndef __APPLE__
.type ConvDwInt8Row, %function
#endif

// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
//                    int output_channel, int input_step, int8_t input_zp)
//
// Accumulates one depthwise-convolution tap into an int32 buffer. For every pixel p in
// [0, num_pixels) and channel c in [0, output_channel):
//   output[p * output_channel + c] += (input[p * input_step + c] - input_zp) * weight[c]
// The output is read and written in place and contiguously across pixels; the input pointer
// advances by input_step bytes per pixel; the weights are re-read from weight_ptr for each pixel.
// Channels are handled 16 at a time, then 8 at a time, then one at a time.
//
// Arguments (AAPCS32):
//   r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels
//   stack: output_channel, input_step, input_zp (loaded into r4-r6)
// Register roles:
//   r0: output read pointer, r7: output write pointer (both walk the same buffer)
//   r8: input pointer within the current pixel, r10: weight pointer within the current pixel
//   r11: channels remaining in the current pixel
//   d30: input_zp broadcast to 8 x int8
//   r9, r12, lr: scalar scratch for the one-at-a-time tail

ConvDwInt8Row:
    // r4-r12 and lr are saved because all of them are overwritten below; lr is used as scratch,
    // so we return by popping it into pc. q4-q7 are callee-saved under AAPCS32
    // (https://static.docs.arm.com/ihi0042/i/aapcs32.pdf).
    push {r4-r12, lr}               // 10 * 4 = 40 bytes
    vpush {q4-q7}                   // 4 * 16 = 64 bytes -> first stack argument is at [sp, #104]

    // sp is deliberately NOT moved back above the save area: anything below sp may be overwritten
    // asynchronously (signal handler, interrupt), which would corrupt the saved registers.
    cmp r3, #0
    ble End                         // nothing to do for num_pixels <= 0 (matches the C loop "i < num_pixels")

    ldr r4, [sp, #104]              // output_channel
    ldr r5, [sp, #108]              // input_step
    ldr r6, [sp, #112]              // input_zp
    vdup.8 d30, r6
    mov r7, r0

  LoopPixel:
    mov r8, r1                      // input
    mov r10, r2                     // weight
    mov r11, r4                     // channels left in this pixel

  // 16-channel path, software pipelined: the second half of one block of 16 is finished while
  // the first half of the next block is already being loaded.
  LoopDepth16In:
    cmp r11, #16
    blt L8
    sub r11, r11, #16

    vld1.8 {q0}, [r8]!
    vld1.16 {q1, q2}, [r10]!        // weight
    vsubl.s8 q3, d0, d30            // input[0..7] - zp, widened to int16
    vld1.32 {q4, q5}, [r0]!
    vmlal.s16 q4, d6, d2
    vmlal.s16 q5, d7, d3

    cmp r11, #16
    blt LoopDepth16Out

  LoopDepth16:
    vst1.32 {q4, q5}, [r7]!
    vsubl.s8 q6, d1, d30            // input[8..15] - zp
    vld1.32 {q7, q8}, [r0]!
    vmlal.s16 q7, d12, d4
    vmlal.s16 q8, d13, d5
    vst1.32 {q7, q8}, [r7]!

    vld1.8 {q0}, [r8]!
    vld1.16 {q1, q2}, [r10]!        // weight
    vsubl.s8 q3, d0, d30            // -zp
    vld1.32 {q4, q5}, [r0]!
    vmlal.s16 q4, d6, d2
    vmlal.s16 q5, d7, d3

    sub r11, r11, #16
    cmp r11, #16
    bge LoopDepth16

  LoopDepth16Out:
    // drain the pipeline: finish the last block of 16
    vst1.32 {q4, q5}, [r7]!
    vsubl.s8 q6, d1, d30
    vld1.32 {q7, q8}, [r0]!
    vmlal.s16 q7, d12, d4
    vmlal.s16 q8, d13, d5
    vst1.32 {q7, q8}, [r7]!

  L8:
    cmp r11, #8
    blt L0

  LoopDepth8:
    vld1.8 {d0}, [r8]!
    vld1.16 {d2, d3}, [r10]!        // weight
    vsubl.s8 q2, d0, d30            // -zp
    vld1.32 {q3}, [r0]!
    vmlal.s16 q3, d4, d2
    vst1.32 {q3}, [r7]!
    vld1.32 {q4}, [r0]!
    vmlal.s16 q4, d5, d3
    vst1.32 {q4}, [r7]!

    sub r11, r11, #8
    cmp r11, #8
    bge LoopDepth8

  L0:
    cmp r11, #0
    beq LoopDepthEnd

  LoopDepth0:
    ldrsb r12, [r8], #1             // input (sign-extended int8)
    ldrsh r9, [r10], #2             // weight (sign-extended int16)
    sub r12, r12, r6                // input - zp
    ldr lr, [r0], #4                // current accumulator
    smlabb r12, r12, r9, lr         // acc += (input - zp) * weight, 16 x 16 -> 32
    str r12, [r7], #4
    subs r11, r11, #1
    bne LoopDepth0                  // r11 != 0 is already known here, no need to re-test at L0

  LoopDepthEnd:
    add r1, r1, r5                  // next pixel: input_ptr += input_step
    subs r3, r3, #1
    bne LoopPixel

  End:
    vpop {q4-q7}
    pop {r4-r12, pc}

#endif
#endif
| @@ -32,6 +32,13 @@ void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t | |||||
| void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, | void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, | ||||
| int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, | int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, | ||||
| int32_t maxi); | int32_t maxi); | ||||
| #ifdef ENABLE_ARM | |||||
| void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | |||||
| int output_channel, int input_step, int8_t input_zp); | |||||
| void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | |||||
| int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); | |||||
| #endif | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, | void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, | ||||
| size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, | size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, | ||||
| @@ -50,10 +57,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con | |||||
| size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, | size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, | ||||
| int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, | int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, | ||||
| int32_t *acc_min, int32_t *acc_max); | int32_t *acc_min, int32_t *acc_max); | ||||
| void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | |||||
| int output_channel, int input_step, int8_t input_zp); | |||||
| void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | |||||
| int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); | |||||
| void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, | void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, | ||||
| int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, | int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, | ||||
| int32_t acc_max); | int32_t acc_max); | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include "nnacl/int8/common_func.h" | #include "nnacl/int8/common_func.h" | ||||
| /*conv depthwise int8 begin*/ | /*conv depthwise int8 begin*/ | ||||
| #ifndef ENABLE_ARM64 | |||||
| #ifndef ENABLE_ARM | |||||
| void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | ||||
| int output_channel, int input_step, int8_t input_zp) { | int output_channel, int input_step, int8_t input_zp) { | ||||
| for (int i = 0; i < num_pixels; i++) { | for (int i = 0; i < num_pixels; i++) { | ||||
| @@ -59,7 +59,7 @@ void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int | |||||
| } else { | } else { | ||||
| int num_pixels = output_w * channel; | int num_pixels = output_w * channel; | ||||
| int align_num = 0; | int align_num = 0; | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_ARM | |||||
| align_num = num_pixels / 4 * 4; | align_num = num_pixels / 4 * 4; | ||||
| ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, | ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, | ||||
| acc_max); | acc_max); | ||||