diff --git a/src/layer/arm/convolution_1x1_int8.h b/src/layer/arm/convolution_1x1_int8.h index 1f144dcec..0ee908a4a 100644 --- a/src/layer/arm/convolution_1x1_int8.h +++ b/src/layer/arm/convolution_1x1_int8.h @@ -17,6 +17,9 @@ #endif // __ARM_NEON #if __aarch64__ +/* + * Convolution 1x1 quantized with int8,unroll 16 x 8 + */ static void conv1x1s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat &_kernel, const Option& opt) { int inch = bottom_blob.c; @@ -25,44 +28,1031 @@ static void conv1x1s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat int outh = top_blob.h; int outch = top_blob.c; - const float *kernel = _kernel; + const signed char* kernel = _kernel; + + int nn_outch = 0; + int remain_outch_start = 0; + + nn_outch = outch >> 3; + remain_outch_start = nn_outch << 3; #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp=0; pp 0; remain--) + int nn = size >> 3; + int remain = size & 7; + + int8x8_t _k_low = vld1_s8(kernel0); + int8x8_t _k_high = vld1_s8(kernel0+8); + int8x16_t _k0 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel1); + _k_high = vld1_s8(kernel1+8); + int8x16_t _k1 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel2); + _k_high = vld1_s8(kernel2+8); + int8x16_t _k2 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel3); + _k_high = vld1_s8(kernel3+8); + int8x16_t _k3 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel4); + _k_high = vld1_s8(kernel4+8); + int8x16_t _k4 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel5); + _k_high = vld1_s8(kernel5+8); + int8x16_t _k5 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel6); + _k_high = vld1_s8(kernel6+8); + int8x16_t _k6 = vcombine_s8(_k_low, _k_high); + + _k_low = vld1_s8(kernel7); + _k_high = vld1_s8(kernel7+8); + int8x16_t _k7 = vcombine_s8(_k_low, _k_high); + + if (nn > 0) { - //ToDo Neon - int sum0 = (int)*r0 * (int)kernel0[0] + (int)*r1 * (int)kernel0[1] + - (int)*r2 * (int)kernel0[2] + (int)*r3 * (int)kernel0[3] + - (int)*r4 * (int)kernel0[4] + (int)*r5 * (int)kernel0[5] + - (int)*r6 * (int)kernel0[6] + (int)*r7 * (int)kernel0[7]; + asm volatile( + "0: \n" + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v8.8b}, [%9], #8 \n" // r0" + "prfm pldl1keep, [%10, #128] \n" + "ld1 {v9.8b}, [%10], #8 \n" // r1" + "prfm pldl1keep, [%11, #128] \n" + "ld1 {v10.8b}, [%11], #8 \n" // r2" + "prfm pldl1keep, [%12, #128] \n" + "ld1 {v11.8b}, [%12], #8 \n" // r3" + "prfm pldl1keep, [%13, #128] \n" + "ld1 {v12.8b}, [%13], #8 \n" // r4" + "prfm pldl1keep, [%14, #128] \n" + "ld1 {v13.8b}, [%14], #8 \n" // r5" + "prfm pldl1keep, [%15, #128] \n" + "ld1 {v14.8b}, [%15], #8 \n" // r6" + "prfm pldl1keep, [%16, #128] \n" + "ld1 {v15.8b}, [%16], #8 \n" // r7" + "prfm pldl1keep, [%17, #128] \n" + "ld1 {v16.8b}, [%17], #8 \n" // r8" + "prfm pldl1keep, [%18, #128] \n" + "ld1 {v17.8b}, [%18], #8 \n" // r9" + "prfm pldl1keep, [%19, #128] \n" + "ld1 {v18.8b}, [%19], #8 \n" // r10" + "prfm pldl1keep, [%20, #128] \n" + "ld1 {v19.8b}, [%20], #8 \n" // r11" + "prfm pldl1keep, [%21, #128] \n" + "ld1 {v20.8b}, [%21], #8 \n" // r12" + "prfm pldl1keep, [%22, #128] \n" + "ld1 {v21.8b}, [%22], #8 \n" // r13" + "prfm pldl1keep, [%23, #128] \n" + "ld1 {v22.8b}, [%23], #8 \n" // r14" + "prfm pldl1keep, [%24, #128] \n" + "ld1 {v23.8b}, [%24], #8 \n" // r15" + + "dup v24.8b, %50.b[0] \n" // k00 + "dup v25.8b, %50.b[1] \n" // k01 + "dup v26.8b, %50.b[2] \n" // k02 + "dup v27.8b, %50.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %50.b[4] \n" // k04 + "dup v25.8b, %50.b[5] \n" // k05 + "dup v26.8b, %50.b[6] \n" // k06 + "dup v27.8b, %50.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %50.b[8] \n" // k08 + "dup v25.8b, %50.b[9] \n" // k09 + "dup v26.8b, %50.b[10] \n" // k10 + "dup v27.8b, %50.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %50.b[12] \n" // k12 + "dup v25.8b, %50.b[13] \n" // k13 + "dup v26.8b, %50.b[14] \n" // k14 + "dup v27.8b, %50.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v29.4s, v30.4s}, [%1] \n" // sum0 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%1], #32 \n" + //########################################### + "dup v24.8b, %51.b[0] \n" // k00 + "dup v25.8b, %51.b[1] \n" // k01 + "dup v26.8b, %51.b[2] \n" // k02 + "dup v27.8b, %51.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %51.b[4] \n" // k04 + "dup v25.8b, %51.b[5] \n" // k05 + "dup v26.8b, %51.b[6] \n" // k06 + "dup v27.8b, %51.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %51.b[8] \n" // k08 + "dup v25.8b, %51.b[9] \n" // k09 + "dup v26.8b, %51.b[10] \n" // k10 + "dup v27.8b, %51.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %51.b[12] \n" // k12 + "dup v25.8b, %51.b[13] \n" // k13 + "dup v26.8b, %51.b[14] \n" // k14 + "dup v27.8b, %51.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v29.4s, v30.4s}, [%2] \n" // sum1 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%2], #32 \n" + //########################################### + "dup v24.8b, %52.b[0] \n" // k00 + "dup v25.8b, %52.b[1] \n" // k01 + "dup v26.8b, %52.b[2] \n" // k02 + "dup v27.8b, %52.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %52.b[4] \n" // k04 + "dup v25.8b, %52.b[5] \n" // k05 + "dup v26.8b, %52.b[6] \n" // k06 + "dup v27.8b, %52.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %52.b[8] \n" // k08 + "dup v25.8b, %52.b[9] \n" // k09 + "dup v26.8b, %52.b[10] \n" // k10 + "dup v27.8b, %52.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %52.b[12] \n" // k12 + "dup v25.8b, %52.b[13] \n" // k13 + "dup v26.8b, %52.b[14] \n" // k14 + "dup v27.8b, %52.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%3, #128] \n" + "ld1 {v29.4s, v30.4s}, [%3] \n" // sum2 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%3], #32 \n" + //########################################### + "dup v24.8b, %53.b[0] \n" // k00 + "dup v25.8b, %53.b[1] \n" // k01 + "dup v26.8b, %53.b[2] \n" // k02 + "dup v27.8b, %53.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %53.b[4] \n" // k04 + "dup v25.8b, %53.b[5] \n" // k05 + "dup v26.8b, %53.b[6] \n" // k06 + "dup v27.8b, %53.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %53.b[8] \n" // k08 + "dup v25.8b, %53.b[9] \n" // k09 + "dup v26.8b, %53.b[10] \n" // k10 + "dup v27.8b, %53.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %53.b[12] \n" // k12 + "dup v25.8b, %53.b[13] \n" // k13 + "dup v26.8b, %53.b[14] \n" // k14 + "dup v27.8b, %53.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v29.4s, v30.4s}, [%4] \n" // sum3 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%4], #32 \n" + //########################################### + "dup v24.8b, %54.b[0] \n" // k00 + "dup v25.8b, %54.b[1] \n" // k01 + "dup v26.8b, %54.b[2] \n" // k02 + "dup v27.8b, %54.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %54.b[4] \n" // k04 + "dup v25.8b, %54.b[5] \n" // k05 + "dup v26.8b, %54.b[6] \n" // k06 + "dup v27.8b, %54.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %54.b[8] \n" // k08 + "dup v25.8b, %54.b[9] \n" // k09 + "dup v26.8b, %54.b[10] \n" // k10 + "dup v27.8b, %54.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %54.b[12] \n" // k12 + "dup v25.8b, %54.b[13] \n" // k13 + "dup v26.8b, %54.b[14] \n" // k14 + "dup v27.8b, %54.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%5, #128] \n" + "ld1 {v29.4s, v30.4s}, [%5] \n" // sum4 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%5], #32 \n" + //########################################### + "dup v24.8b, %55.b[0] \n" // k00 + "dup v25.8b, %55.b[1] \n" // k01 + "dup v26.8b, %55.b[2] \n" // k02 + "dup v27.8b, %55.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %55.b[4] \n" // k04 + "dup v25.8b, %55.b[5] \n" // k05 + "dup v26.8b, %55.b[6] \n" // k06 + "dup v27.8b, %55.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %55.b[8] \n" // k08 + "dup v25.8b, %55.b[9] \n" // k09 + "dup v26.8b, %55.b[10] \n" // k10 + "dup v27.8b, %55.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %55.b[12] \n" // k12 + "dup v25.8b, %55.b[13] \n" // k13 + "dup v26.8b, %55.b[14] \n" // k14 + "dup v27.8b, %55.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%6, #128] \n" + "ld1 {v29.4s, v30.4s}, [%6] \n" // sum5 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%6], #32 \n" + //########################################### + "dup v24.8b, %56.b[0] \n" // k00 + "dup v25.8b, %56.b[1] \n" // k01 + "dup v26.8b, %56.b[2] \n" // k02 + "dup v27.8b, %56.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %56.b[4] \n" // k04 + "dup v25.8b, %56.b[5] \n" // k05 + "dup v26.8b, %56.b[6] \n" // k06 + "dup v27.8b, %56.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %56.b[8] \n" // k08 + "dup v25.8b, %56.b[9] \n" // k09 + "dup v26.8b, %56.b[10] \n" // k10 + "dup v27.8b, %56.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %56.b[12] \n" // k12 + "dup v25.8b, %56.b[13] \n" // k13 + "dup v26.8b, %56.b[14] \n" // k14 + "dup v27.8b, %56.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%7, #128] \n" + "ld1 {v29.4s, v30.4s}, [%7] \n" // sum6 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%7], #32 \n" + //########################################### + "dup v24.8b, %57.b[0] \n" // k00 + "dup v25.8b, %57.b[1] \n" // k01 + "dup v26.8b, %57.b[2] \n" // k02 + "dup v27.8b, %57.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %57.b[4] \n" // k04 + "dup v25.8b, %57.b[5] \n" // k05 + "dup v26.8b, %57.b[6] \n" // k06 + "dup v27.8b, %57.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %57.b[8] \n" // k08 + "dup v25.8b, %57.b[9] \n" // k09 + "dup v26.8b, %57.b[10] \n" // k10 + "dup v27.8b, %57.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %57.b[12] \n" // k12 + "dup v25.8b, %57.b[13] \n" // k13 + "dup v26.8b, %57.b[14] \n" // k14 + "dup v27.8b, %57.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v29.4s, v30.4s}, [%8] \n" // sum7 + "saddw v29.4s, v29.4s, v28.4h \n" + "saddw2 v30.4s, v30.4s, v28.8h \n" + "st1 {v29.4s, v30.4s}, [%8], #32 \n" + //########################################### + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(outptr0),// %1 + "=r"(outptr1),// %2 + "=r"(outptr2),// %3 + "=r"(outptr3),// %4 + "=r"(outptr4),// %5 + "=r"(outptr5),// %6 + "=r"(outptr6),// %7 + "=r"(outptr7),// %8 + "=r"(r0), // %9 + "=r"(r1), // %10 + "=r"(r2), // %11 + "=r"(r3), // %12 + "=r"(r4), // %13 + "=r"(r5), // %14 + "=r"(r6), // %15 + "=r"(r7), // %16 + "=r"(r8), // %17 + "=r"(r9), // %18 + "=r"(r10), // %19 + "=r"(r11), // %20 + "=r"(r12), // %21 + "=r"(r13), // %22 + "=r"(r14), // %23 + "=r"(r15) // %24 + : "0"(nn), + "1"(outptr0), + "2"(outptr1), + "3"(outptr2), + "4"(outptr3), + "5"(outptr4), + "6"(outptr5), + "7"(outptr6), + "8"(outptr7), + "9"(r0), + "10"(r1), + "11"(r2), + "12"(r3), + "13"(r4), + "14"(r5), + "15"(r6), + "16"(r7), + "17"(r8), + "18"(r9), + "19"(r10), + "20"(r11), + "21"(r12), + "22"(r13), + "23"(r14), + "24"(r15), + "w"(_k0), // %50 + "w"(_k1), // %51 + "w"(_k2), // %52 + "w"(_k3), // %53 + "w"(_k4), // %54 + "w"(_k5), // %55 + "w"(_k6), // %56 + "w"(_k7) // %57 + : "cc", "memory", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30" + ); + } + + if (remain >= 4) + { + remain -= 4; + + asm volatile( + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v8.8b}, [%9], #8 \n" // r0" + "prfm pldl1keep, [%10, #128] \n" + "ld1 {v9.8b}, [%10], #8 \n" // r1" + "prfm pldl1keep, [%11, #128] \n" + "ld1 {v10.8b}, [%11], #8 \n" // r2" + "prfm pldl1keep, [%12, #128] \n" + "ld1 {v11.8b}, [%12], #8 \n" // r3" + "prfm pldl1keep, [%13, #128] \n" + "ld1 {v12.8b}, [%13], #8 \n" // r4" + "prfm pldl1keep, [%14, #128] \n" + "ld1 {v13.8b}, [%14], #8 \n" // r5" + "prfm pldl1keep, [%15, #128] \n" + "ld1 {v14.8b}, [%15], #8 \n" // r6" + "prfm pldl1keep, [%16, #128] \n" + "ld1 {v15.8b}, [%16], #8 \n" // r7" + "prfm pldl1keep, [%17, #128] \n" + "ld1 {v16.8b}, [%17], #8 \n" // r8" + "prfm pldl1keep, [%18, #128] \n" + "ld1 {v17.8b}, [%18], #8 \n" // r9" + "prfm pldl1keep, [%19, #128] \n" + "ld1 {v18.8b}, [%19], #8 \n" // r10" + "prfm pldl1keep, [%20, #128] \n" + "ld1 {v19.8b}, [%20], #8 \n" // r11" + "prfm pldl1keep, [%21, #128] \n" + "ld1 {v20.8b}, [%21], #8 \n" // r12" + "prfm pldl1keep, [%22, #128] \n" + "ld1 {v21.8b}, [%22], #8 \n" // r13" + "prfm pldl1keep, [%23, #128] \n" + "ld1 {v22.8b}, [%23], #8 \n" // r14" + "prfm pldl1keep, [%24, #128] \n" + "ld1 {v23.8b}, [%24], #8 \n" // r15" + + "dup v24.8b, %50.b[0] \n" // k00 + "dup v25.8b, %50.b[1] \n" // k01 + "dup v26.8b, %50.b[2] \n" // k02 + "dup v27.8b, %50.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %50.b[4] \n" // k04 + "dup v25.8b, %50.b[5] \n" // k05 + "dup v26.8b, %50.b[6] \n" // k06 + "dup v27.8b, %50.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %50.b[8] \n" // k08 + "dup v25.8b, %50.b[9] \n" // k09 + "dup v26.8b, %50.b[10] \n" // k10 + "dup v27.8b, %50.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %50.b[12] \n" // k12 + "dup v25.8b, %50.b[13] \n" // k13 + "dup v26.8b, %50.b[14] \n" // k14 + "dup v27.8b, %50.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v29.4s}, [%1] \n" // sum0 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%1], #16 \n" + //########################################### + "dup v24.8b, %51.b[0] \n" // k00 + "dup v25.8b, %51.b[1] \n" // k01 + "dup v26.8b, %51.b[2] \n" // k02 + "dup v27.8b, %51.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %51.b[4] \n" // k04 + "dup v25.8b, %51.b[5] \n" // k05 + "dup v26.8b, %51.b[6] \n" // k06 + "dup v27.8b, %51.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %51.b[8] \n" // k08 + "dup v25.8b, %51.b[9] \n" // k09 + "dup v26.8b, %51.b[10] \n" // k10 + "dup v27.8b, %51.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %51.b[12] \n" // k12 + "dup v25.8b, %51.b[13] \n" // k13 + "dup v26.8b, %51.b[14] \n" // k14 + "dup v27.8b, %51.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v29.4s}, [%2] \n" // sum1 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%2], #16 \n" + //########################################### + "dup v24.8b, %52.b[0] \n" // k00 + "dup v25.8b, %52.b[1] \n" // k01 + "dup v26.8b, %52.b[2] \n" // k02 + "dup v27.8b, %52.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %52.b[4] \n" // k04 + "dup v25.8b, %52.b[5] \n" // k05 + "dup v26.8b, %52.b[6] \n" // k06 + "dup v27.8b, %52.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %52.b[8] \n" // k08 + "dup v25.8b, %52.b[9] \n" // k09 + "dup v26.8b, %52.b[10] \n" // k10 + "dup v27.8b, %52.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %52.b[12] \n" // k12 + "dup v25.8b, %52.b[13] \n" // k13 + "dup v26.8b, %52.b[14] \n" // k14 + "dup v27.8b, %52.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%3, #128] \n" + "ld1 {v29.4s}, [%3] \n" // sum2 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%3], #16 \n" + //########################################### + "dup v24.8b, %53.b[0] \n" // k00 + "dup v25.8b, %53.b[1] \n" // k01 + "dup v26.8b, %53.b[2] \n" // k02 + "dup v27.8b, %53.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %53.b[4] \n" // k04 + "dup v25.8b, %53.b[5] \n" // k05 + "dup v26.8b, %53.b[6] \n" // k06 + "dup v27.8b, %53.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %53.b[8] \n" // k08 + "dup v25.8b, %53.b[9] \n" // k09 + "dup v26.8b, %53.b[10] \n" // k10 + "dup v27.8b, %53.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %53.b[12] \n" // k12 + "dup v25.8b, %53.b[13] \n" // k13 + "dup v26.8b, %53.b[14] \n" // k14 + "dup v27.8b, %53.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v29.4s}, [%4] \n" // sum3 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%4], #16 \n" + //########################################### + "dup v24.8b, %54.b[0] \n" // k00 + "dup v25.8b, %54.b[1] \n" // k01 + "dup v26.8b, %54.b[2] \n" // k02 + "dup v27.8b, %54.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %54.b[4] \n" // k04 + "dup v25.8b, %54.b[5] \n" // k05 + "dup v26.8b, %54.b[6] \n" // k06 + "dup v27.8b, %54.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %54.b[8] \n" // k08 + "dup v25.8b, %54.b[9] \n" // k09 + "dup v26.8b, %54.b[10] \n" // k10 + "dup v27.8b, %54.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %54.b[12] \n" // k12 + "dup v25.8b, %54.b[13] \n" // k13 + "dup v26.8b, %54.b[14] \n" // k14 + "dup v27.8b, %54.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%5, #128] \n" + "ld1 {v29.4s}, [%5] \n" // sum4 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%5], #16 \n" + //########################################### + "dup v24.8b, %55.b[0] \n" // k00 + "dup v25.8b, %55.b[1] \n" // k01 + "dup v26.8b, %55.b[2] \n" // k02 + "dup v27.8b, %55.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %55.b[4] \n" // k04 + "dup v25.8b, %55.b[5] \n" // k05 + "dup v26.8b, %55.b[6] \n" // k06 + "dup v27.8b, %55.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %55.b[8] \n" // k08 + "dup v25.8b, %55.b[9] \n" // k09 + "dup v26.8b, %55.b[10] \n" // k10 + "dup v27.8b, %55.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %55.b[12] \n" // k12 + "dup v25.8b, %55.b[13] \n" // k13 + "dup v26.8b, %55.b[14] \n" // k14 + "dup v27.8b, %55.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%6, #128] \n" + "ld1 {v29.4s}, [%6] \n" // sum5 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%6], #16 \n" + //########################################### + "dup v24.8b, %56.b[0] \n" // k00 + "dup v25.8b, %56.b[1] \n" // k01 + "dup v26.8b, %56.b[2] \n" // k02 + "dup v27.8b, %56.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %56.b[4] \n" // k04 + "dup v25.8b, %56.b[5] \n" // k05 + "dup v26.8b, %56.b[6] \n" // k06 + "dup v27.8b, %56.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %56.b[8] \n" // k08 + "dup v25.8b, %56.b[9] \n" // k09 + "dup v26.8b, %56.b[10] \n" // k10 + "dup v27.8b, %56.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %56.b[12] \n" // k12 + "dup v25.8b, %56.b[13] \n" // k13 + "dup v26.8b, %56.b[14] \n" // k14 + "dup v27.8b, %56.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%7, #128] \n" + "ld1 {v29.4s}, [%7] \n" // sum6 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%7], #16 \n" + //########################################### + "dup v24.8b, %57.b[0] \n" // k00 + "dup v25.8b, %57.b[1] \n" // k01 + "dup v26.8b, %57.b[2] \n" // k02 + "dup v27.8b, %57.b[3] \n" // k03 + "smull v28.8h, v8.8b, v24.8b \n" + "smlal v28.8h, v9.8b, v25.8b \n" + "smlal v28.8h, v10.8b, v26.8b \n" + "smlal v28.8h, v11.8b, v27.8b \n" + + "dup v24.8b, %57.b[4] \n" // k04 + "dup v25.8b, %57.b[5] \n" // k05 + "dup v26.8b, %57.b[6] \n" // k06 + "dup v27.8b, %57.b[7] \n" // k07 + "smlal v28.8h, v12.8b, v24.8b \n" + "smlal v28.8h, v13.8b, v25.8b \n" + "smlal v28.8h, v14.8b, v26.8b \n" + "smlal v28.8h, v15.8b, v27.8b \n" + + "dup v24.8b, %57.b[8] \n" // k08 + "dup v25.8b, %57.b[9] \n" // k09 + "dup v26.8b, %57.b[10] \n" // k10 + "dup v27.8b, %57.b[11] \n" // k11 + "smlal v28.8h, v16.8b, v24.8b \n" + "smlal v28.8h, v17.8b, v25.8b \n" + "smlal v28.8h, v18.8b, v26.8b \n" + "smlal v28.8h, v19.8b, v27.8b \n" + + "dup v24.8b, %57.b[12] \n" // k12 + "dup v25.8b, %57.b[13] \n" // k13 + "dup v26.8b, %57.b[14] \n" // k14 + "dup v27.8b, %57.b[15] \n" // k15 + "smlal v28.8h, v20.8b, v24.8b \n" + "smlal v28.8h, v21.8b, v25.8b \n" + "smlal v28.8h, v22.8b, v26.8b \n" + "smlal v28.8h, v23.8b, v27.8b \n" + + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v29.4s}, [%8] \n" // sum7 + "saddw v29.4s, v29.4s, v28.4h \n" + "st1 {v29.4s}, [%8], #16 \n" + //########################################### + "sub %w9, %w9, #4 \n" + "sub %w10, %w10, #4 \n" + "sub %w11, %w11, #4 \n" + "sub %w12, %w12, #4 \n" + "sub %w13, %w13, #4 \n" + "sub %w14, %w14, #4 \n" + "sub %w15, %w15, #4 \n" + "sub %w16, %w16, #4 \n" + "sub %w17, %w17, #4 \n" + "sub %w18, %w18, #4 \n" + "sub %w19, %w19, #4 \n" + "sub %w20, %w20, #4 \n" + "sub %w21, %w21, #4 \n" + "sub %w22, %w22, #4 \n" + "sub %w23, %w23, #4 \n" + "sub %w24, %w24, #4 \n" + : "=r"(nn), // %0 + "=r"(outptr0),// %1 + "=r"(outptr1),// %2 + "=r"(outptr2),// %3 + "=r"(outptr3),// %4 + "=r"(outptr4),// %5 + "=r"(outptr5),// %6 + "=r"(outptr6),// %7 + "=r"(outptr7),// %8 + "=r"(r0), // %9 + "=r"(r1), // %10 + "=r"(r2), // %11 + "=r"(r3), // %12 + "=r"(r4), // %13 + "=r"(r5), // %14 + "=r"(r6), // %15 + "=r"(r7), // %16 + "=r"(r8), // %17 + "=r"(r9), // %18 + "=r"(r10), // %19 + "=r"(r11), // %20 + "=r"(r12), // %21 + "=r"(r13), // %22 + "=r"(r14), // %23 + "=r"(r15) // %24 + : "0"(nn), + "1"(outptr0), + "2"(outptr1), + "3"(outptr2), + "4"(outptr3), + "5"(outptr4), + "6"(outptr5), + "7"(outptr6), + "8"(outptr7), + "9"(r0), + "10"(r1), + "11"(r2), + "12"(r3), + "13"(r4), + "14"(r5), + "15"(r6), + "16"(r7), + "17"(r8), + "18"(r9), + "19"(r10), + "20"(r11), + "21"(r12), + "22"(r13), + "23"(r14), + "24"(r15), + "w"(_k0), // %50 + "w"(_k1), // %51 + "w"(_k2), // %52 + "w"(_k3), // %53 + "w"(_k4), // %54 + "w"(_k5), // %55 + "w"(_k6), // %56 + "w"(_k7) // %57 + : "cc", "memory", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30" + ); + } + + for (; remain>0; remain--) + { + // TODO neon optimize + int sum0 = (int)*r0 * kernel0[0] + *r1 * kernel0[1] + *r2 * kernel0[2] + *r3 * kernel0[3] + *r4 * kernel0[4] + *r5 * kernel0[5] + *r6 * kernel0[6] + *r7 * kernel0[7] + *r8 * kernel0[8] + *r9 * kernel0[9] + *r10 * kernel0[10] + *r11 * kernel0[11] + *r12 * kernel0[12] + *r13 * kernel0[13] + *r14 * kernel0[14] + *r15 * kernel0[15]; + int sum1 = (int)*r0 * kernel1[0] + *r1 * kernel1[1] + *r2 * kernel1[2] + *r3 * kernel1[3] + *r4 * kernel1[4] + *r5 * kernel1[5] + *r6 * kernel1[6] + *r7 * kernel1[7] + *r8 * kernel1[8] + *r9 * kernel1[9] + *r10 * kernel1[10] + *r11 * kernel1[11] + *r12 * kernel1[12] + *r13 * kernel1[13] + *r14 * kernel1[14] + *r15 * kernel1[15]; + int sum2 = (int)*r0 * kernel2[0] + *r1 * kernel2[1] + *r2 * kernel2[2] + *r3 * kernel2[3] + *r4 * kernel2[4] + *r5 * kernel2[5] + *r6 * kernel2[6] + *r7 * kernel2[7] + *r8 * kernel2[8] + *r9 * kernel2[9] + *r10 * kernel2[10] + *r11 * kernel2[11] + *r12 * kernel2[12] + *r13 * kernel2[13] + *r14 * kernel2[14] + *r15 * kernel2[15]; + int sum3 = (int)*r0 * kernel3[0] + *r1 * kernel3[1] + *r2 * kernel3[2] + *r3 * kernel3[3] + *r4 * kernel3[4] + *r5 * kernel3[5] + *r6 * kernel3[6] + *r7 * kernel3[7] + *r8 * kernel3[8] + *r9 * kernel3[9] + *r10 * kernel3[10] + *r11 * kernel3[11] + *r12 * kernel3[12] + *r13 * kernel3[13] + *r14 * kernel3[14] + *r15 * kernel3[15]; + int sum4 = (int)*r0 * kernel4[0] + *r1 * kernel4[1] + *r2 * kernel4[2] + *r3 * kernel4[3] + *r4 * kernel4[4] + *r5 * kernel4[5] + *r6 * kernel4[6] + *r7 * kernel4[7] + *r8 * kernel4[8] + *r9 * kernel4[9] + *r10 * kernel4[10] + *r11 * kernel4[11] + *r12 * kernel4[12] + *r13 * kernel4[13] + *r14 * kernel4[14] + *r15 * kernel4[15]; + int sum5 = (int)*r0 * kernel5[0] + *r1 * kernel5[1] + *r2 * kernel5[2] + *r3 * kernel5[3] + *r4 * kernel5[4] + *r5 * kernel5[5] + *r6 * kernel5[6] + *r7 * kernel5[7] + *r8 * kernel5[8] + *r9 * kernel5[9] + *r10 * kernel5[10] + *r11 * kernel5[11] + *r12 * kernel5[12] + *r13 * kernel5[13] + *r14 * kernel5[14] + *r15 * kernel5[15]; + int sum6 = (int)*r0 * kernel6[0] + *r1 * kernel6[1] + *r2 * kernel6[2] + *r3 * kernel6[3] + *r4 * kernel6[4] + *r5 * kernel6[5] + *r6 * kernel6[6] + *r7 * kernel6[7] + *r8 * kernel6[8] + *r9 * kernel6[9] + *r10 * kernel6[10] + *r11 * kernel6[11] + *r12 * kernel6[12] + *r13 * kernel6[13] + *r14 * kernel6[14] + *r15 * kernel6[15]; + int sum7 = (int)*r0 * kernel7[0] + *r1 * kernel7[1] + *r2 * kernel7[2] + *r3 * kernel7[3] + *r4 * kernel7[4] + *r5 * kernel7[5] + *r6 * kernel7[6] + *r7 * kernel7[7] + *r8 * kernel7[8] + *r9 * kernel7[9] + *r10 * kernel7[10] + *r11 * kernel7[11] + *r12 * kernel7[12] + *r13 * kernel7[13] + *r14 * kernel7[14] + *r15 * kernel7[15]; *outptr0 += sum0; + *outptr1 += sum1; + *outptr2 += sum2; + *outptr3 += sum3; + *outptr4 += sum4; + *outptr5 += sum5; + *outptr6 += sum6; + *outptr7 += sum7; r0++; r1++; @@ -72,33 +1062,348 @@ static void conv1x1s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat r5++; r6++; r7++; + r8++; + r9++; + r10++; + r11++; + r12++; + r13++; + r14++; + r15++; outptr0++; + outptr1++; + outptr2++; + outptr3++; + outptr4++; + outptr5++; + outptr6++; + outptr7++; } } for (; q 0; remain--) + int nn = size >> 3; + int remain = size & 7; + + int8x8_t _k0 = vdup_n_s8(k0); + int8x8_t _k1 = vdup_n_s8(k1); + int8x8_t _k2 = vdup_n_s8(k2); + int8x8_t _k3 = vdup_n_s8(k3); + int8x8_t _k4 = vdup_n_s8(k4); + int8x8_t _k5 = vdup_n_s8(k5); + int8x8_t _k6 = vdup_n_s8(k6); + int8x8_t _k7 = vdup_n_s8(k7); + + for (; nn>0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + + int32x4_t _out0 = vld1q_s32(outptr0); + int32x4_t _out0n = vld1q_s32(outptr0+4); + int32x4_t _out1 = vld1q_s32(outptr1); + int32x4_t _out1n = vld1q_s32(outptr1+4); + int32x4_t _out2 = vld1q_s32(outptr2); + int32x4_t _out2n = vld1q_s32(outptr2+4); + int32x4_t _out3 = vld1q_s32(outptr3); + int32x4_t _out3n = vld1q_s32(outptr3+4); + int32x4_t _out4 = vld1q_s32(outptr4); + int32x4_t _out4n = vld1q_s32(outptr4+4); + int32x4_t _out5 = vld1q_s32(outptr5); + int32x4_t _out5n = vld1q_s32(outptr5+4); + int32x4_t _out6 = vld1q_s32(outptr6); + int32x4_t _out6n = vld1q_s32(outptr6+4); + int32x4_t _out7 = vld1q_s32(outptr7); + int32x4_t _out7n = vld1q_s32(outptr7+4); + + int16x8_t _out0_s16 = vmull_s8(_r0, _k0); + int16x8_t _out1_s16 = vmull_s8(_r0, _k1); + int16x8_t _out2_s16 = vmull_s8(_r0, _k2); + int16x8_t _out3_s16 = vmull_s8(_r0, _k3); + int16x8_t _out4_s16 = vmull_s8(_r0, _k4); + int16x8_t _out5_s16 = vmull_s8(_r0, _k5); + int16x8_t _out6_s16 = vmull_s8(_r0, _k6); + int16x8_t _out7_s16 = vmull_s8(_r0, _k7); + + _out0 = vaddw_s16(_out0, vget_low_s16(_out0_s16)); + _out0n = vaddw_s16(_out0n, vget_high_s16(_out0_s16)); + _out1 = vaddw_s16(_out1, vget_low_s16(_out1_s16)); + _out1n = vaddw_s16(_out1n, vget_high_s16(_out1_s16)); + _out2 = vaddw_s16(_out2, vget_low_s16(_out2_s16)); + _out2n = vaddw_s16(_out2n, vget_high_s16(_out2_s16)); + _out3 = vaddw_s16(_out3, vget_low_s16(_out3_s16)); + _out3n = vaddw_s16(_out3n, vget_high_s16(_out3_s16)); + _out4 = vaddw_s16(_out4, vget_low_s16(_out4_s16)); + _out4n = vaddw_s16(_out4n, vget_high_s16(_out4_s16)); + _out5 = vaddw_s16(_out5, vget_low_s16(_out5_s16)); + _out5n = vaddw_s16(_out5n, vget_high_s16(_out5_s16)); + _out6 = vaddw_s16(_out6, vget_low_s16(_out6_s16)); + _out6n = vaddw_s16(_out6n, vget_high_s16(_out6_s16)); + _out7 = vaddw_s16(_out7, vget_low_s16(_out7_s16)); + _out7n = vaddw_s16(_out7n, vget_high_s16(_out7_s16)); + + vst1q_s32(outptr0, _out0); + vst1q_s32(outptr0+4, _out0n); + vst1q_s32(outptr1, _out1); + vst1q_s32(outptr1+4, _out1n); + vst1q_s32(outptr2, _out2); + vst1q_s32(outptr2+4, _out2n); + vst1q_s32(outptr3, _out3); + vst1q_s32(outptr3+4, _out3n); + vst1q_s32(outptr4, _out4); + vst1q_s32(outptr4+4, _out4n); + vst1q_s32(outptr5, _out5); + vst1q_s32(outptr5+4, _out5n); + vst1q_s32(outptr6, _out6); + vst1q_s32(outptr6+4, _out6n); + vst1q_s32(outptr7, _out7); + vst1q_s32(outptr7+4, _out7n); + + r0 += 8; + outptr0 += 8; + outptr1 += 8; + outptr2 += 8; + outptr3 += 8; + outptr4 += 8; + outptr5 += 8; + outptr6 += 8; + outptr7 += 8; + } + + for (; remain>0; remain--) { - int sum0 = (int)(*r0) * (int)k0; + // TODO neon optimize + int sum0 = (int)*r0 * k0; + int sum1 = (int)*r0 * k1; + int sum2 = (int)*r0 * k2; + int sum3 = (int)*r0 * k3; + int sum4 = (int)*r0 * k4; + int sum5 = (int)*r0 * k5; + int sum6 = (int)*r0 * k6; + int sum7 = (int)*r0 * k7; *outptr0 += sum0; + *outptr1 += sum1; + *outptr2 += sum2; + *outptr3 += sum3; + *outptr4 += sum4; + *outptr5 += sum5; + *outptr6 += sum6; + *outptr7 += sum7; r0++; outptr0++; + outptr1++; + outptr2++; + outptr3++; + outptr4++; + outptr5++; + outptr6++; + outptr7++; } } } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p=remain_outch_start; p> 3; + int remain = size & 7; + + int8x8_t _k0 = vdup_n_s8(k0); + int8x8_t _k1 = vdup_n_s8(k1); + int8x8_t _k2 = vdup_n_s8(k2); + int8x8_t _k3 = vdup_n_s8(k3); + int8x8_t _k4 = vdup_n_s8(k4); + int8x8_t _k5 = vdup_n_s8(k5); + int8x8_t _k6 = vdup_n_s8(k6); + int8x8_t _k7 = vdup_n_s8(k7); + + for (; nn>0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r3 = vld1_s8(r3); + int8x8_t _r4 = vld1_s8(r4); + int8x8_t _r5 = vld1_s8(r5); + int8x8_t _r6 = vld1_s8(r6); + int8x8_t _r7 = vld1_s8(r7); + + int32x4_t _out0 = vld1q_s32(outptr); + int32x4_t _out0n = vld1q_s32(outptr+4); + + int16x8_t _out0_s16 = vmull_s8(_r0, _k0); + _out0_s16 = vmlal_s8(_out0_s16, _r1, _k1); + _out0_s16 = vmlal_s8(_out0_s16, _r2, _k2); + _out0_s16 = vmlal_s8(_out0_s16, _r3, _k3); + _out0_s16 = vmlal_s8(_out0_s16, _r4, _k4); + _out0_s16 = vmlal_s8(_out0_s16, _r5, _k5); + _out0_s16 = vmlal_s8(_out0_s16, _r6, _k6); + _out0_s16 = vmlal_s8(_out0_s16, _r7, _k7); + + _out0 = vaddw_s16(_out0, vget_low_s16(_out0_s16)); + _out0n = vaddw_s16(_out0n, vget_high_s16(_out0_s16)); + + vst1q_s32(outptr, _out0); + vst1q_s32(outptr+4, _out0n); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + outptr += 8; + } + + for (; remain>0; remain--) + { + int sum = (int)*r0 * k0; + int sum1 = (int)*r1 * k1; + int sum2 = (int)*r2 * k2; + int sum3 = (int)*r3 * k3; + int sum4 = (int)*r4 * k4; + int sum5 = (int)*r5 * k5; + int sum6 = (int)*r6 * k6; + int sum7 = (int)*r7 * k7; + + *outptr += sum + sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7; + + r0++; + r1++; + r2++; + r3++; + r4++; + r5++; + r6++; + r7++; + outptr++; + } + + } + + for (; q> 3; + int remain = size & 7; + + int8x8_t _k0 = vdup_n_s8(k0); + + if (nn > 0) + { + int8x8_t _r0 = vld1_s8(r0); + + int32x4_t _out0 = vld1q_s32(outptr); + int32x4_t _out0n = vld1q_s32(outptr+4); + + int16x8_t _out0_s16 = vmull_s8(_r0, _k0); + + _out0 = vaddw_s16(_out0, vget_low_s16(_out0_s16)); + _out0n = vaddw_s16(_out0n, vget_high_s16(_out0_s16)); + + vst1q_s32(outptr, _out0); + vst1q_s32(outptr+4, _out0n); + + r0 += 8; + outptr += 8; + } + + for (; remain>0; remain--) + { + int sum = (int)*r0 * k0; + + *outptr += sum; + + r0++; + outptr++; + } + } + } } #else // __aarch64__ /* diff --git a/src/layer/arm/convolution_3x3_int8.h b/src/layer/arm/convolution_3x3_int8.h index 190b09812..b155d5fba 100644 --- a/src/layer/arm/convolution_3x3_int8.h +++ b/src/layer/arm/convolution_3x3_int8.h @@ -73,42 +73,195 @@ static void conv3x3s1_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel static void conv3x3s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat &_kernel, const Option& opt) { int w = bottom_blob.w; - //int h = bottom_blob.h; int inch = bottom_blob.c; int outw = top_blob.w; int outh = top_blob.h; int outch = top_blob.c; - const signed char *kernel = _kernel; + const signed char* kernel = _kernel; + + int nn_outch = outch >> 1; + int remain_outch_start = nn_outch << 1; #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp=0; pp < nn_outch; pp++) { + int p = pp * 2; + Mat out0 = top_blob.channel(p); + Mat out1 = top_blob.channel(p+1); out0.fill(0); + out1.fill(0); - const signed char *kernel0 = (const signed char *)kernel + p * inch * 9; - - for (int q = 0; q < inch; q++) + const signed char* kernel0 = (const signed char *)kernel + p * inch * 9; + const signed char* kernel1 = (const signed char *)kernel + (p + 1) * inch * 9; + + for (int q=0; q> 3; + int remain = outw & 7; + + for (; nn > 0; nn--) + { + // outch 0 + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k00); + _sum0 = vmlal_s8(_sum0, _r01, _k01); + _sum0 = vmlal_s8(_sum0, _r02, _k02); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k03); + _sum0 = vmlal_s8(_sum0, _r11, _k04); + _sum0 = vmlal_s8(_sum0, _r12, _k05); + + int16x8_t _sum1 = vmull_s8(_r1, _k00); + _sum1 = vmlal_s8(_sum1, _r11, _k01); + _sum1 = vmlal_s8(_sum1, _r12, _k02); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k06); + _sum0 = vmlal_s8(_sum0, _r21, _k07); + _sum0 = vmlal_s8(_sum0, _r22, _k08); + + _sum1 = vmlal_s8(_sum1, _r2, _k03); + _sum1 = vmlal_s8(_sum1, _r21, _k04); + _sum1 = vmlal_s8(_sum1, _r22, _k05); + + int8x8_t _r3 = vld1_s8(r3); + int8x8_t _r3n = vld1_s8(r3+8); + int8x8_t _r31 = vext_s8(_r3, _r3n, 1); + int8x8_t _r32 = vext_s8(_r3, _r3n, 2); + _sum1 = vmlal_s8(_sum1, _r3, _k06); + _sum1 = vmlal_s8(_sum1, _r31, _k07); + _sum1 = vmlal_s8(_sum1, _r32, _k08); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + int32x4_t sum0n_s32 = vld1q_s32(outptr0+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + int32x4_t sum1_s32 = vld1q_s32(outptr0n); + int32x4_t sum1n_s32 = vld1q_s32(outptr0n+4); + + sum1_s32 = vaddw_s16(sum1_s32, vget_low_s16(_sum1)); + sum1n_s32 = vaddw_s16(sum1n_s32, vget_high_s16(_sum1)); + + vst1q_s32(outptr0n, sum1_s32); + vst1q_s32(outptr0n+4, sum1n_s32); + + // outch 1 + _sum0 = vmull_s8(_r0, _k10); + _sum0 = vmlal_s8(_sum0, _r01, _k11); + _sum0 = vmlal_s8(_sum0, _r02, _k12); + + _sum0 = vmlal_s8(_sum0, _r1, _k13); + _sum0 = vmlal_s8(_sum0, _r11, _k14); + _sum0 = vmlal_s8(_sum0, _r12, _k15); + + _sum0 = vmlal_s8(_sum0, _r2, _k16); + _sum0 = vmlal_s8(_sum0, _r21, _k17); + _sum0 = vmlal_s8(_sum0, _r22, _k18); + + _sum1 = vmull_s8(_r1, _k10); + _sum1 = vmlal_s8(_sum1, _r11, _k11); + _sum1 = vmlal_s8(_sum1, _r12, _k12); + + _sum1 = vmlal_s8(_sum1, _r2, _k13); + _sum1 = vmlal_s8(_sum1, _r21, _k14); + _sum1 = vmlal_s8(_sum1, _r22, _k15); + + _sum1 = vmlal_s8(_sum1, _r3, _k16); + _sum1 = vmlal_s8(_sum1, _r31, _k17); + _sum1 = vmlal_s8(_sum1, _r32, _k18); + + sum0_s32 = vld1q_s32(outptr1); + sum0n_s32 = vld1q_s32(outptr1+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr1, sum0_s32); + vst1q_s32(outptr1+4, sum0n_s32); + + sum1_s32 = vld1q_s32(outptr1n); + sum1n_s32 = vld1q_s32(outptr1n+4); + + sum1_s32 = vaddw_s16(sum1_s32, vget_low_s16(_sum1)); + sum1n_s32 = vaddw_s16(sum1n_s32, vget_high_s16(_sum1)); + + vst1q_s32(outptr1n, sum1_s32); + vst1q_s32(outptr1n+4, sum1n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outptr0 += 8; + outptr1 += 8; + outptr0n += 8; + outptr1n += 8; + } - for (; remain > 0; remain--) + for (; remain>0; remain--) { int sum0 = 0; + int sum0n = 0; + int sum1 = 0; + int sum1n = 0; + //ToDo Neon sum0 += (int)r0[0] * kernel0[0]; sum0 += (int)r0[1] * kernel0[1]; sum0 += (int)r0[2] * kernel0[2]; @@ -119,12 +272,166 @@ static void conv3x3s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat sum0 += (int)r2[1] * kernel0[7]; sum0 += (int)r2[2] * kernel0[8]; + sum1 += (int)r0[0] * kernel1[0]; + sum1 += (int)r0[1] * kernel1[1]; + sum1 += (int)r0[2] * kernel1[2]; + sum1 += (int)r1[0] * kernel1[3]; + sum1 += (int)r1[1] * kernel1[4]; + sum1 += (int)r1[2] * kernel1[5]; + sum1 += (int)r2[0] * kernel1[6]; + sum1 += (int)r2[1] * kernel1[7]; + sum1 += (int)r2[2] * kernel1[8]; + + sum0n += (int)r1[0] * kernel0[0]; + sum0n += (int)r1[1] * kernel0[1]; + sum0n += (int)r1[2] * kernel0[2]; + sum0n += (int)r2[0] * kernel0[3]; + sum0n += (int)r2[1] * kernel0[4]; + sum0n += (int)r2[2] * kernel0[5]; + sum0n += (int)r3[0] * kernel0[6]; + sum0n += (int)r3[1] * kernel0[7]; + sum0n += (int)r3[2] * kernel0[8]; + + sum1n += (int)r1[0] * kernel1[0]; + sum1n += (int)r1[1] * kernel1[1]; + sum1n += (int)r1[2] * kernel1[2]; + sum1n += (int)r2[0] * kernel1[3]; + sum1n += (int)r2[1] * kernel1[4]; + sum1n += (int)r2[2] * kernel1[5]; + sum1n += (int)r3[0] * kernel1[6]; + sum1n += (int)r3[1] * kernel1[7]; + sum1n += (int)r3[2] * kernel1[8]; + *outptr0 += sum0; + *outptr1 += sum1; + *outptr0n += sum0n; + *outptr1n += sum1n; r0++; r1++; r2++; + r3++; outptr0++; + outptr1++; + outptr0n++; + outptr1n++; + } + + r0 += 2 + w; + r1 += 2 + w; + r2 += 2 + w; + r3 += 2 + w; + + outptr0 += outw; + outptr1 += outw; + outptr0n += outw; + outptr1n += outw; + } + + for (; i < outh; i++) + { + int nn = outw >> 3; + int remain = outw & 7; + + for (; nn > 0; nn--) + { + // outch 0 + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k00); + _sum0 = vmlal_s8(_sum0, _r01, _k01); + _sum0 = vmlal_s8(_sum0, _r02, _k02); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k03); + _sum0 = vmlal_s8(_sum0, _r11, _k04); + _sum0 = vmlal_s8(_sum0, _r12, _k05); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k06); + _sum0 = vmlal_s8(_sum0, _r21, _k07); + _sum0 = vmlal_s8(_sum0, _r22, _k08); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + int32x4_t sum0n_s32 = vld1q_s32(outptr0+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + // outch 1 + _sum0 = vmull_s8(_r0, _k10); + _sum0 = vmlal_s8(_sum0, _r01, _k11); + _sum0 = vmlal_s8(_sum0, _r02, _k12); + + _sum0 = vmlal_s8(_sum0, _r1, _k13); + _sum0 = vmlal_s8(_sum0, _r11, _k14); + _sum0 = vmlal_s8(_sum0, _r12, _k15); + + _sum0 = vmlal_s8(_sum0, _r2, _k16); + _sum0 = vmlal_s8(_sum0, _r21, _k17); + _sum0 = vmlal_s8(_sum0, _r22, _k18); + + sum0_s32 = vld1q_s32(outptr1); + sum0n_s32 = vld1q_s32(outptr1+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr1, sum0_s32); + vst1q_s32(outptr1+4, sum0n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr0 += 8; + outptr1 += 8; + } + + for (; remain>0; remain--) + { + int sum0 = 0; + int sum1 = 0; + + sum0 += (int)r0[0] * kernel0[0]; + sum0 += (int)r0[1] * kernel0[1]; + sum0 += (int)r0[2] * kernel0[2]; + sum0 += (int)r1[0] * kernel0[3]; + sum0 += (int)r1[1] * kernel0[4]; + sum0 += (int)r1[2] * kernel0[5]; + sum0 += (int)r2[0] * kernel0[6]; + sum0 += (int)r2[1] * kernel0[7]; + sum0 += (int)r2[2] * kernel0[8]; + + sum1 += (int)r0[0] * kernel1[0]; + sum1 += (int)r0[1] * kernel1[1]; + sum1 += (int)r0[2] * kernel1[2]; + sum1 += (int)r1[0] * kernel1[3]; + sum1 += (int)r1[1] * kernel1[4]; + sum1 += (int)r1[2] * kernel1[5]; + sum1 += (int)r2[0] * kernel1[6]; + sum1 += (int)r2[1] * kernel1[7]; + sum1 += (int)r2[2] * kernel1[8]; + + *outptr0 += sum0; + *outptr1 += sum1; + + r0++; + r1++; + r2++; + outptr0++; + outptr1++; } r0 += 2; @@ -133,14 +440,245 @@ static void conv3x3s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat } kernel0 += 9; + kernel1 += 9; } } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p=remain_outch_start; p> 3; + int remain = outw & 7; + + for (; nn > 0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k00); + _sum0 = vmlal_s8(_sum0, _r01, _k01); + _sum0 = vmlal_s8(_sum0, _r02, _k02); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k03); + _sum0 = vmlal_s8(_sum0, _r11, _k04); + _sum0 = vmlal_s8(_sum0, _r12, _k05); + + int16x8_t _sum1 = vmull_s8(_r1, _k00); + _sum1 = vmlal_s8(_sum1, _r11, _k01); + _sum1 = vmlal_s8(_sum1, _r12, _k02); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k06); + _sum0 = vmlal_s8(_sum0, _r21, _k07); + _sum0 = vmlal_s8(_sum0, _r22, _k08); + + _sum1 = vmlal_s8(_sum1, _r2, _k03); + _sum1 = vmlal_s8(_sum1, _r21, _k04); + _sum1 = vmlal_s8(_sum1, _r22, _k05); + + int8x8_t _r3 = vld1_s8(r3); + int8x8_t _r3n = vld1_s8(r3+8); + int8x8_t _r31 = vext_s8(_r3, _r3n, 1); + int8x8_t _r32 = vext_s8(_r3, _r3n, 2); + _sum1 = vmlal_s8(_sum1, _r3, _k06); + _sum1 = vmlal_s8(_sum1, _r31, _k07); + _sum1 = vmlal_s8(_sum1, _r32, _k08); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + int32x4_t sum0n_s32 = vld1q_s32(outptr0+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + int32x4_t sum1_s32 = vld1q_s32(outptr0n); + int32x4_t sum1n_s32 = vld1q_s32(outptr0n+4); + + sum1_s32 = vaddw_s16(sum1_s32, vget_low_s16(_sum1)); + sum1n_s32 = vaddw_s16(sum1n_s32, vget_high_s16(_sum1)); + + vst1q_s32(outptr0n, sum1_s32); + vst1q_s32(outptr0n+4, sum1n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outptr0 += 8; + outptr0n += 8; + } + + for (; remain>0; remain--) + { + // Todo neon + int sum0 = 0; + int sum0n = 0; + + sum0 += (int)r0[0] * kernel0[0]; + sum0 += (int)r0[1] * kernel0[1]; + sum0 += (int)r0[2] * kernel0[2]; + sum0 += (int)r1[0] * kernel0[3]; + sum0 += (int)r1[1] * kernel0[4]; + sum0 += (int)r1[2] * kernel0[5]; + sum0 += (int)r2[0] * kernel0[6]; + sum0 += (int)r2[1] * kernel0[7]; + sum0 += (int)r2[2] * kernel0[8]; + + sum0n += (int)r1[0] * kernel0[0]; + sum0n += (int)r1[1] * kernel0[1]; + sum0n += (int)r1[2] * kernel0[2]; + sum0n += (int)r2[0] * kernel0[3]; + sum0n += (int)r2[1] * kernel0[4]; + sum0n += (int)r2[2] * kernel0[5]; + sum0n += (int)r3[0] * kernel0[6]; + sum0n += (int)r3[1] * kernel0[7]; + sum0n += (int)r3[2] * kernel0[8]; + + *outptr0 += sum0; + *outptr0n += sum0n; + + r0++; + r1++; + r2++; + r3++; + outptr0++; + outptr0n++; + } + + r0 += 2 + w; + r1 += 2 + w; + r2 += 2 + w; + r3 += 2 + w; + + outptr0 += outw; + outptr0n += outw; + } + + for (; i < outh; i++) + { + int nn = outw >> 3; + int remain = outw & 7; + + for (; nn > 0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k00); + _sum0 = vmlal_s8(_sum0, _r01, _k01); + _sum0 = vmlal_s8(_sum0, _r02, _k02); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k03); + _sum0 = vmlal_s8(_sum0, _r11, _k04); + _sum0 = vmlal_s8(_sum0, _r12, _k05); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k06); + _sum0 = vmlal_s8(_sum0, _r21, _k07); + _sum0 = vmlal_s8(_sum0, _r22, _k08); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + int32x4_t sum0n_s32 = vld1q_s32(outptr0+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum0)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr0 += 8; + } + + for (; remain>0; remain--) + { + int sum0 = 0; + + sum0 += (int)r0[0] * kernel0[0]; + sum0 += (int)r0[1] * kernel0[1]; + sum0 += (int)r0[2] * kernel0[2]; + sum0 += (int)r1[0] * kernel0[3]; + sum0 += (int)r1[1] * kernel0[4]; + sum0 += (int)r1[2] * kernel0[5]; + sum0 += (int)r2[0] * kernel0[6]; + sum0 += (int)r2[1] * kernel0[7]; + sum0 += (int)r2[2] * kernel0[8]; + + *outptr0 += sum0; + + r0++; + r1++; + r2++; + outptr0++; + } + + r0 += 2; + r1 += 2; + r2 += 2; + } + kernel0 += 9; + } + } } -static void conv3x3s2_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat &_kernel, const Option& opt) +static void conv3x3s2_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) { int w = bottom_blob.w; - //int h = bottom_blob.h; + int h = bottom_blob.h; int inch = bottom_blob.c; int outw = top_blob.w; @@ -149,52 +687,429 @@ static void conv3x3s2_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat const int tailstep = w - 2 * outw + w; - const signed char *kernel = _kernel; + const signed char* kernel = _kernel; + + int nn_outch = outch >> 2; + int remain_outch_start = nn_outch << 2; #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp=0; pp < nn_outch; pp++) { - Mat out0 = top_blob.channel(p); + int p = pp * 4; - out0.fill(0); + Mat out0 = top_blob.channel(p); + Mat out1 = top_blob.channel(p + 1); + Mat out2 = top_blob.channel(p + 2); + Mat out3 = top_blob.channel(p + 3); + + out0.fill(0.f); + out1.fill(0.f); + out2.fill(0.f); + out3.fill(0.f); - const signed char *kernel0 = (const signed char *)kernel + p * inch * 9; + const signed char* kernel0 = (const signed char*)kernel + p * inch * 9; + const signed char* kernel1 = (const signed char*)kernel + (p + 1) * inch * 9; + const signed char* kernel2 = (const signed char*)kernel + (p + 2) * inch * 9; + const signed char* kernel3 = (const signed char*)kernel + (p + 3) * inch * 9; - for (int q = 0; q < inch; q++) + for (int q=0; q> 3; + int remain = outw & 7; - for (; remain > 0; remain--) + if (nn > 0) + { + asm volatile( + "0: \n" + // r0 + "prfm pldl1keep, [%5, #128] \n" + "ld2 {v4.8b, v5.8b}, [%5], #16 \n" + "ld2 {v6.8b, v7.8b}, [%5] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[0] \n" + "dup v10.8b, %17.b[0] \n" + "dup v11.8b, %18.b[0] \n" + "dup v12.8b, %19.b[0] \n" + + "smull v13.8h, v4.8b, v9.8b \n" + "smull v14.8h, v4.8b, v10.8b \n" + "smull v15.8h, v4.8b, v11.8b \n" + "smull v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[1] \n" + "dup v10.8b, %17.b[1] \n" + "dup v11.8b, %18.b[1] \n" + "dup v12.8b, %19.b[1] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[2] \n" + "dup v10.8b, %17.b[2] \n" + "dup v11.8b, %18.b[2] \n" + "dup v12.8b, %19.b[2] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // r1 + "prfm pldl1keep, [%6, #128] \n" + "ld2 {v4.8b, v5.8b}, [%6], #16 \n" + "ld2 {v6.8b, v7.8b}, [%6] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[3] \n" + "dup v10.8b, %17.b[3] \n" + "dup v11.8b, %18.b[3] \n" + "dup v12.8b, %19.b[3] \n" + + "smlal v13.8h, v4.8b, v9.8b \n" + "smlal v14.8h, v4.8b, v10.8b \n" + "smlal v15.8h, v4.8b, v11.8b \n" + "smlal v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[4] \n" + "dup v10.8b, %17.b[4] \n" + "dup v11.8b, %18.b[4] \n" + "dup v12.8b, %19.b[4] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[5] \n" + "dup v10.8b, %17.b[5] \n" + "dup v11.8b, %18.b[5] \n" + "dup v12.8b, %19.b[5] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // r2 + "prfm pldl1keep, [%7, #128] \n" + "ld2 {v4.8b, v5.8b}, [%7], #16 \n" + "ld2 {v6.8b, v7.8b}, [%7] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[6] \n" + "dup v10.8b, %17.b[6] \n" + "dup v11.8b, %18.b[6] \n" + "dup v12.8b, %19.b[6] \n" + + "smlal v13.8h, v4.8b, v9.8b \n" + "smlal v14.8h, v4.8b, v10.8b \n" + "smlal v15.8h, v4.8b, v11.8b \n" + "smlal v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[7] \n" + "dup v10.8b, %17.b[7] \n" + "dup v11.8b, %18.b[7] \n" + "dup v12.8b, %19.b[7] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[8] \n" + "dup v10.8b, %17.b[8] \n" + "dup v11.8b, %18.b[8] \n" + "dup v12.8b, %19.b[8] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // sum0 - sum3 + "prfm pldl1keep, [%1, #128] \n" + "prfm pldl1keep, [%2, #128] \n" + "prfm pldl1keep, [%3, #128] \n" + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v17.4s, v18.4s}, [%1] \n" + "ld1 {v19.4s, v20.4s}, [%2] \n" + "ld1 {v21.4s, v22.4s}, [%3] \n" + "ld1 {v23.4s, v24.4s}, [%4] \n" + + "saddw v17.4s, v17.4s, v13.4h \n" + "saddw2 v18.4s, v18.4s, v13.8h \n" + "saddw v19.4s, v19.4s, v14.4h \n" + "saddw2 v20.4s, v20.4s, v14.8h \n" + "saddw v21.4s, v21.4s, v15.4h \n" + "saddw2 v22.4s, v22.4s, v15.8h \n" + "saddw v23.4s, v23.4s, v16.4h \n" + "saddw2 v24.4s, v24.4s, v16.8h \n" + "st1 {v17.4s, v18.4s}, [%1], #32\n" + "st1 {v19.4s, v20.4s}, [%2], #32\n" + "st1 {v21.4s, v22.4s}, [%3], #32\n" + "st1 {v23.4s, v24.4s}, [%4], #32\n" + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nn), //%0 + "=r"(outptr0), //%1 + "=r"(outptr1), //%2 + "=r"(outptr2), //%3 + "=r"(outptr3), //%4 + "=r"(r0), //%5 + "=r"(r1), //%6 + "=r"(r2) //%7 + : "0"(nn), + "1"(outptr0), + "2"(outptr1), + "3"(outptr2), + "4"(outptr3), + "5"(r0), + "6"(r1), + "7"(r2), + "w"(_k0), //%16 + "w"(_k1), //%17 + "w"(_k2), //%18 + "w"(_k3) //%19 + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24" + ); + } + + if (remain >= 4) + { + remain -= 4; + + asm volatile( + // r0 + "prfm pldl1keep, [%5, #128] \n" + "ld2 {v4.8b, v5.8b}, [%5], #16 \n" + "ld2 {v6.8b, v7.8b}, [%5] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[0] \n" + "dup v10.8b, %17.b[0] \n" + "dup v11.8b, %18.b[0] \n" + "dup v12.8b, %19.b[0] \n" + + "smull v13.8h, v4.8b, v9.8b \n" + "smull v14.8h, v4.8b, v10.8b \n" + "smull v15.8h, v4.8b, v11.8b \n" + "smull v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[1] \n" + "dup v10.8b, %17.b[1] \n" + "dup v11.8b, %18.b[1] \n" + "dup v12.8b, %19.b[1] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[2] \n" + "dup v10.8b, %17.b[2] \n" + "dup v11.8b, %18.b[2] \n" + "dup v12.8b, %19.b[2] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // r1 + "prfm pldl1keep, [%6, #128] \n" + "ld2 {v4.8b, v5.8b}, [%6], #16 \n" + "ld2 {v6.8b, v7.8b}, [%6] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[3] \n" + "dup v10.8b, %17.b[3] \n" + "dup v11.8b, %18.b[3] \n" + "dup v12.8b, %19.b[3] \n" + + "smlal v13.8h, v4.8b, v9.8b \n" + "smlal v14.8h, v4.8b, v10.8b \n" + "smlal v15.8h, v4.8b, v11.8b \n" + "smlal v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[4] \n" + "dup v10.8b, %17.b[4] \n" + "dup v11.8b, %18.b[4] \n" + "dup v12.8b, %19.b[4] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[5] \n" + "dup v10.8b, %17.b[5] \n" + "dup v11.8b, %18.b[5] \n" + "dup v12.8b, %19.b[5] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // r2 + "prfm pldl1keep, [%7, #128] \n" + "ld2 {v4.8b, v5.8b}, [%7], #16 \n" + "ld2 {v6.8b, v7.8b}, [%7] \n" + "ext v8.8b, v4.8b, v6.8b, #1 \n" + + "dup v9.8b, %16.b[6] \n" + "dup v10.8b, %17.b[6] \n" + "dup v11.8b, %18.b[6] \n" + "dup v12.8b, %19.b[6] \n" + + "smlal v13.8h, v4.8b, v9.8b \n" + "smlal v14.8h, v4.8b, v10.8b \n" + "smlal v15.8h, v4.8b, v11.8b \n" + "smlal v16.8h, v4.8b, v12.8b \n" + + "dup v9.8b, %16.b[7] \n" + "dup v10.8b, %17.b[7] \n" + "dup v11.8b, %18.b[7] \n" + "dup v12.8b, %19.b[7] \n" + + "smlal v13.8h, v5.8b, v9.8b \n" + "smlal v14.8h, v5.8b, v10.8b \n" + "smlal v15.8h, v5.8b, v11.8b \n" + "smlal v16.8h, v5.8b, v12.8b \n" + + "dup v9.8b, %16.b[8] \n" + "dup v10.8b, %17.b[8] \n" + "dup v11.8b, %18.b[8] \n" + "dup v12.8b, %19.b[8] \n" + + "smlal v13.8h, v8.8b, v9.8b \n" + "smlal v14.8h, v8.8b, v10.8b \n" + "smlal v15.8h, v8.8b, v11.8b \n" + "smlal v16.8h, v8.8b, v12.8b \n" + // sum0 - sum3 + "prfm pldl1keep, [%1, #128] \n" + "prfm pldl1keep, [%2, #128] \n" + "prfm pldl1keep, [%3, #128] \n" + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v17.4s}, [%1] \n" + "ld1 {v19.4s}, [%2] \n" + "ld1 {v21.4s}, [%3] \n" + "ld1 {v23.4s}, [%4] \n" + + "saddw v17.4s, v17.4s, v13.4h \n" + "saddw v19.4s, v19.4s, v14.4h \n" + "saddw v21.4s, v21.4s, v15.4h \n" + "saddw v23.4s, v23.4s, v16.4h \n" + + "st1 {v17.4s}, [%1], #16 \n" + "st1 {v19.4s}, [%2], #16 \n" + "st1 {v21.4s}, [%3], #16 \n" + "st1 {v23.4s}, [%4], #16 \n" + "sub %5, %5, #8 \n" + "sub %6, %6, #8 \n" + "sub %7, %7, #8 \n" + : "=r"(nn), //%0 + "=r"(outptr0), //%1 + "=r"(outptr1), //%2 + "=r"(outptr2), //%3 + "=r"(outptr3), //%4 + "=r"(r0), //%5 + "=r"(r1), //%6 + "=r"(r2) //%7 + : "0"(nn), + "1"(outptr0), + "2"(outptr1), + "3"(outptr2), + "4"(outptr3), + "5"(r0), + "6"(r1), + "7"(r2), + "w"(_k0), //%16 + "w"(_k1), //%17 + "w"(_k2), //%18 + "w"(_k3) //%19 + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24" + ); + } + + for (; remain>0; remain--) { int sum0 = 0; + int sum1 = 0; + int sum2 = 0; + int sum3 = 0; + + sum0 += (int)r0[0] * kernel0[0]; + sum0 += (int)r0[1] * kernel0[1]; + sum0 += (int)r0[2] * kernel0[2]; + sum0 += (int)r1[0] * kernel0[3]; + sum0 += (int)r1[1] * kernel0[4]; + sum0 += (int)r1[2] * kernel0[5]; + sum0 += (int)r2[0] * kernel0[6]; + sum0 += (int)r2[1] * kernel0[7]; + sum0 += (int)r2[2] * kernel0[8]; + + sum1 += (int)r0[0] * kernel1[0]; + sum1 += (int)r0[1] * kernel1[1]; + sum1 += (int)r0[2] * kernel1[2]; + sum1 += (int)r1[0] * kernel1[3]; + sum1 += (int)r1[1] * kernel1[4]; + sum1 += (int)r1[2] * kernel1[5]; + sum1 += (int)r2[0] * kernel1[6]; + sum1 += (int)r2[1] * kernel1[7]; + sum1 += (int)r2[2] * kernel1[8]; - sum0 += (int)r0[0] * (int)kernel0[0]; - sum0 += (int)r0[1] * (int)kernel0[1]; - sum0 += (int)r0[2] * (int)kernel0[2]; - sum0 += (int)r1[0] * (int)kernel0[3]; - sum0 += (int)r1[1] * (int)kernel0[4]; - sum0 += (int)r1[2] * (int)kernel0[5]; - sum0 += (int)r2[0] * (int)kernel0[6]; - sum0 += (int)r2[1] * (int)kernel0[7]; - sum0 += (int)r2[2] * (int)kernel0[8]; + sum2 += (int)r0[0] * kernel2[0]; + sum2 += (int)r0[1] * kernel2[1]; + sum2 += (int)r0[2] * kernel2[2]; + sum2 += (int)r1[0] * kernel2[3]; + sum2 += (int)r1[1] * kernel2[4]; + sum2 += (int)r1[2] * kernel2[5]; + sum2 += (int)r2[0] * kernel2[6]; + sum2 += (int)r2[1] * kernel2[7]; + sum2 += (int)r2[2] * kernel2[8]; + + sum3 += (int)r0[0] * kernel3[0]; + sum3 += (int)r0[1] * kernel3[1]; + sum3 += (int)r0[2] * kernel3[2]; + sum3 += (int)r1[0] * kernel3[3]; + sum3 += (int)r1[1] * kernel3[4]; + sum3 += (int)r1[2] * kernel3[5]; + sum3 += (int)r2[0] * kernel3[6]; + sum3 += (int)r2[1] * kernel3[7]; + sum3 += (int)r2[2] * kernel3[8]; *outptr0 += sum0; + *outptr1 += sum1; + *outptr2 += sum2; + *outptr3 += sum3; r0 += 2; r1 += 2; r2 += 2; outptr0++; - } + outptr1++; + outptr2++; + outptr3++; + } r0 += tailstep; r1 += tailstep; @@ -202,8 +1117,167 @@ static void conv3x3s2_int8_neon(const Mat &bottom_blob, Mat &top_blob, const Mat } kernel0 += 9; + kernel1 += 9; + kernel2 += 9; + kernel3 += 9; } } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p=remain_outch_start; p> 3; + int remain = outw & 7; + + remain = outw; + + for (; nn >0; nn--) + { + int8x8x2_t _r0 = vld2_s8(r0); + int8x8x2_t _r0n = vld2_s8(r0+16); + int8x8_t _r00 = _r0.val[0]; + int8x8_t _r01 = _r0.val[1]; + int8x8_t _r02 = vext_s8(_r00, _r0n.val[0], 1); + + int16x8_t _sum = vmull_s8(_r00, _k0); + _sum = vmlal_s8(_sum, _r01, _k1); + _sum = vmlal_s8(_sum, _r02, _k2); + + int8x8x2_t _r1 = vld2_s8(r1); + int8x8x2_t _r1n = vld2_s8(r1+16); + int8x8_t _r10 = _r1.val[0]; + int8x8_t _r11 = _r1.val[1]; + int8x8_t _r12 = vext_s8(_r10, _r1n.val[0], 1); + _sum = vmlal_s8(_sum, _r10, _k3); + _sum = vmlal_s8(_sum, _r11, _k4); + _sum = vmlal_s8(_sum, _r12, _k5); + + int8x8x2_t _r2 = vld2_s8(r2); + int8x8x2_t _r2n = vld2_s8(r2+16); + int8x8_t _r20 = _r2.val[0]; + int8x8_t _r21 = _r2.val[1]; + int8x8_t _r22 = vext_s8(_r20, _r2n.val[0], 1); + _sum = vmlal_s8(_sum, _r20, _k6); + _sum = vmlal_s8(_sum, _r21, _k7); + _sum = vmlal_s8(_sum, _r22, _k8); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + int32x4_t sum0n_s32 = vld1q_s32(outptr0+4); + + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum)); + sum0n_s32 = vaddw_s16(sum0n_s32, vget_high_s16(_sum)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr0 += 8; + } + + if (remain >= 4) + { + remain -= 4; + + int8x8x2_t _r0 = vld2_s8(r0); + int8x8x2_t _r0n = vld2_s8(r0+16); + int8x8_t _r00 = _r0.val[0]; + int8x8_t _r01 = _r0.val[1]; + int8x8_t _r02 = vext_s8(_r00, _r0n.val[0], 1); + + int16x8_t _sum = vmull_s8(_r00, _k0); + _sum = vmlal_s8(_sum, _r01, _k1); + _sum = vmlal_s8(_sum, _r02, _k2); + + int8x8x2_t _r1 = vld2_s8(r1); + int8x8x2_t _r1n = vld2_s8(r1+16); + int8x8_t _r10 = _r1.val[0]; + int8x8_t _r11 = _r1.val[1]; + int8x8_t _r12 = vext_s8(_r10, _r1n.val[0], 1); + _sum = vmlal_s8(_sum, _r10, _k3); + _sum = vmlal_s8(_sum, _r11, _k4); + _sum = vmlal_s8(_sum, _r12, _k5); + + int8x8x2_t _r2 = vld2_s8(r2); + int8x8x2_t _r2n = vld2_s8(r2+16); + int8x8_t _r20 = _r2.val[0]; + int8x8_t _r21 = _r2.val[1]; + int8x8_t _r22 = vext_s8(_r20, _r2n.val[0], 1); + _sum = vmlal_s8(_sum, _r20, _k6); + _sum = vmlal_s8(_sum, _r21, _k7); + _sum = vmlal_s8(_sum, _r22, _k8); + + int32x4_t sum0_s32 = vld1q_s32(outptr0); + sum0_s32 = vaddw_s16(sum0_s32, vget_low_s16(_sum)); + vst1q_s32(outptr0, sum0_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr0 += 4; + } + + for (; remain>0; remain--) + { + int sum0 = 0; + + sum0 += (int)r0[0] * kernel0[0]; + sum0 += (int)r0[1] * kernel0[1]; + sum0 += (int)r0[2] * kernel0[2]; + sum0 += (int)r1[0] * kernel0[3]; + sum0 += (int)r1[1] * kernel0[4]; + sum0 += (int)r1[2] * kernel0[5]; + sum0 += (int)r2[0] * kernel0[6]; + sum0 += (int)r2[1] * kernel0[7]; + sum0 += (int)r2[2] * kernel0[8]; + + *outptr0 += sum0; + + r0 += 2; + r1 += 2; + r2 += 2; + outptr0++; + } + + r0 += tailstep; + r1 += tailstep; + r2 += tailstep; + } + + kernel0 += 9; + } + } } #else // __aarch64__ static void conv3x3s1_neon_s8(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) diff --git a/src/layer/arm/convolutiondepthwise_3x3_int8.h b/src/layer/arm/convolutiondepthwise_3x3_int8.h index 5da7f6bb9..f00325850 100644 --- a/src/layer/arm/convolutiondepthwise_3x3_int8.h +++ b/src/layer/arm/convolutiondepthwise_3x3_int8.h @@ -30,42 +30,212 @@ static void convdw3x3s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const M { Mat out = top_blob.channel(p); - const signed char *kernel0 = (const signed char *)_kernel + p * 9; + const signed char* kernel = (const signed char *)_kernel + p*9; + + int* outptr0 = out; + int* outptr0n = outptr0 + outw; + + const signed char* img0 = bottom_blob.channel(p); + + const signed char* r0 = img0; + const signed char* r1 = img0 + w; + const signed char* r2 = img0 + w*2; + const signed char* r3 = img0 + w*3; + + int i = 0; + + int8x8_t _k0 = vdup_n_s8(kernel[0]); + int8x8_t _k1 = vdup_n_s8(kernel[1]); + int8x8_t _k2 = vdup_n_s8(kernel[2]); + + int8x8_t _k3 = vdup_n_s8(kernel[3]); + int8x8_t _k4 = vdup_n_s8(kernel[4]); + int8x8_t _k5 = vdup_n_s8(kernel[5]); - int *outptr = out; + int8x8_t _k6 = vdup_n_s8(kernel[6]); + int8x8_t _k7 = vdup_n_s8(kernel[7]); + int8x8_t _k8 = vdup_n_s8(kernel[8]); - const signed char *img0 = bottom_blob.channel(p); + for (; i+1 < outh; i+=2) + { + int nn = outw >> 3; + int remain = outw & 7; - const signed char *r0 = img0; - const signed char *r1 = img0 + w; - const signed char *r2 = img0 + w * 2; + for (; nn >0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k0); + _sum0 = vmlal_s8(_sum0, _r01, _k1); + _sum0 = vmlal_s8(_sum0, _r02, _k2); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k3); + _sum0 = vmlal_s8(_sum0, _r11, _k4); + _sum0 = vmlal_s8(_sum0, _r12, _k5); + + int16x8_t _sum1 = vmull_s8(_r1, _k0); + _sum1 = vmlal_s8(_sum1, _r11, _k1); + _sum1 = vmlal_s8(_sum1, _r12, _k2); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k6); + _sum0 = vmlal_s8(_sum0, _r21, _k7); + _sum0 = vmlal_s8(_sum0, _r22, _k8); + + _sum1 = vmlal_s8(_sum1, _r2, _k3); + _sum1 = vmlal_s8(_sum1, _r21, _k4); + _sum1 = vmlal_s8(_sum1, _r22, _k5); + + int8x8_t _r3 = vld1_s8(r3); + int8x8_t _r3n = vld1_s8(r3+8); + int8x8_t _r31 = vext_s8(_r3, _r3n, 1); + int8x8_t _r32 = vext_s8(_r3, _r3n, 2); + _sum1 = vmlal_s8(_sum1, _r3, _k6); + _sum1 = vmlal_s8(_sum1, _r31, _k7); + _sum1 = vmlal_s8(_sum1, _r32, _k8); + + int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum0)); + int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + int32x4_t sum1_s32 = vmovl_s16(vget_low_s16(_sum1)); + int32x4_t sum1n_s32 = vmovl_s16(vget_high_s16(_sum1)); + + vst1q_s32(outptr0n, sum1_s32); + vst1q_s32(outptr0n+4, sum1n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outptr0 += 8; + outptr0n += 8; + } + + for (; remain>0; remain--) + { + //Todo Neon + + int sum0 = 0; + int sum0n = 0; + + sum0 += (int)r0[0] * kernel[0]; + sum0 += (int)r0[1] * kernel[1]; + sum0 += (int)r0[2] * kernel[2]; + sum0 += (int)r1[0] * kernel[3]; + sum0 += (int)r1[1] * kernel[4]; + sum0 += (int)r1[2] * kernel[5]; + sum0 += (int)r2[0] * kernel[6]; + sum0 += (int)r2[1] * kernel[7]; + sum0 += (int)r2[2] * kernel[8]; + + sum0n += (int)r1[0] * kernel[0]; + sum0n += (int)r1[1] * kernel[1]; + sum0n += (int)r1[2] * kernel[2]; + sum0n += (int)r2[0] * kernel[3]; + sum0n += (int)r2[1] * kernel[4]; + sum0n += (int)r2[2] * kernel[5]; + sum0n += (int)r3[0] * kernel[6]; + sum0n += (int)r3[1] * kernel[7]; + sum0n += (int)r3[2] * kernel[8]; + + *outptr0 = sum0; + *outptr0n = sum0n; + + r0++; + r1++; + r2++; + r3++; + outptr0++; + outptr0n++; + } + + r0 += 2 + w; + r1 += 2 + w; + r2 += 2 + w; + r3 += 2 + w; + + outptr0 += outw; + outptr0n += outw; + } - int i = 0; for (; i < outh; i++) { - int remain = outw; + int nn = outw >> 3; + int remain = outw & 7; + + for (; nn >0; nn--) + { + int8x8_t _r0 = vld1_s8(r0); + int8x8_t _r0n = vld1_s8(r0+8); + int8x8_t _r01 = vext_s8(_r0, _r0n, 1); + int8x8_t _r02 = vext_s8(_r0, _r0n, 2); + + int16x8_t _sum0 = vmull_s8(_r0, _k0); + _sum0 = vmlal_s8(_sum0, _r01, _k1); + _sum0 = vmlal_s8(_sum0, _r02, _k2); + + int8x8_t _r1 = vld1_s8(r1); + int8x8_t _r1n = vld1_s8(r1+8); + int8x8_t _r11 = vext_s8(_r1, _r1n, 1); + int8x8_t _r12 = vext_s8(_r1, _r1n, 2); + _sum0 = vmlal_s8(_sum0, _r1, _k3); + _sum0 = vmlal_s8(_sum0, _r11, _k4); + _sum0 = vmlal_s8(_sum0, _r12, _k5); + + int8x8_t _r2 = vld1_s8(r2); + int8x8_t _r2n = vld1_s8(r2+8); + int8x8_t _r21 = vext_s8(_r2, _r2n, 1); + int8x8_t _r22 = vext_s8(_r2, _r2n, 2); + _sum0 = vmlal_s8(_sum0, _r2, _k6); + _sum0 = vmlal_s8(_sum0, _r21, _k7); + _sum0 = vmlal_s8(_sum0, _r22, _k8); + + int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum0)); + int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum0)); + + vst1q_s32(outptr0, sum0_s32); + vst1q_s32(outptr0+4, sum0n_s32); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr0 += 8; + } - for (; remain > 0; remain--) + for (; remain>0; remain--) { int sum = 0; - sum += (int)r0[0] * (int)kernel0[0]; - sum += (int)r0[1] * (int)kernel0[1]; - sum += (int)r0[2] * (int)kernel0[2]; - sum += (int)r1[0] * (int)kernel0[3]; - sum += (int)r1[1] * (int)kernel0[4]; - sum += (int)r1[2] * (int)kernel0[5]; - sum += (int)r2[0] * (int)kernel0[6]; - sum += (int)r2[1] * (int)kernel0[7]; - sum += (int)r2[2] * (int)kernel0[8]; + sum += (int)r0[0] * kernel[0]; + sum += (int)r0[1] * kernel[1]; + sum += (int)r0[2] * kernel[2]; + sum += (int)r1[0] * kernel[3]; + sum += (int)r1[1] * kernel[4]; + sum += (int)r1[2] * kernel[5]; + sum += (int)r2[0] * kernel[6]; + sum += (int)r2[1] * kernel[7]; + sum += (int)r2[2] * kernel[8]; - *outptr = sum; + *outptr0 = sum; r0++; r1++; r2++; - outptr++; - } + outptr0++; + } r0 += 2; r1 += 2; @@ -82,42 +252,95 @@ static void convdw3x3s2_int8_neon(const Mat &bottom_blob, Mat &top_blob, const M int outh = top_blob.h; int outch = top_blob.c; - const int tailstep = w - 2 * outw + w; + const int tailstep = w - 2*outw + w; #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int p=0; p> 3; + int remain = outw & 7; - for (; remain > 0; remain--) + for (; nn > 0; nn--) { - int sum = 0; + int8x8x2_t _r0 = vld2_s8(r0); + int8x8x2_t _r0n = vld2_s8(r0+16); + int8x8_t _r00 = _r0.val[0]; + int8x8_t _r01 = _r0.val[1]; + int8x8_t _r02 = vext_s8(_r00, _r0n.val[0], 1); + + int16x8_t _sum = vmull_s8(_r00, _k0); + _sum = vmlal_s8(_sum, _r01, _k1); + _sum = vmlal_s8(_sum, _r02, _k2); + + int8x8x2_t _r1 = vld2_s8(r1); + int8x8x2_t _r1n = vld2_s8(r1+16); + int8x8_t _r10 = _r1.val[0]; + int8x8_t _r11 = _r1.val[1]; + int8x8_t _r12 = vext_s8(_r10, _r1n.val[0], 1); + _sum = vmlal_s8(_sum, _r10, _k3); + _sum = vmlal_s8(_sum, _r11, _k4); + _sum = vmlal_s8(_sum, _r12, _k5); + + int8x8x2_t _r2 = vld2_s8(r2); + int8x8x2_t _r2n = vld2_s8(r2+16); + int8x8_t _r20 = _r2.val[0]; + int8x8_t _r21 = _r2.val[1]; + int8x8_t _r22 = vext_s8(_r20, _r2n.val[0], 1); + _sum = vmlal_s8(_sum, _r20, _k6); + _sum = vmlal_s8(_sum, _r21, _k7); + _sum = vmlal_s8(_sum, _r22, _k8); + + int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum)); + int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum)); + + vst1q_s32(outptr, sum0_s32); + vst1q_s32(outptr+4, sum0n_s32); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr += 8; + } - sum += (int)r0[0] * (int)kernel0[0]; - sum += (int)r0[1] * (int)kernel0[1]; - sum += (int)r0[2] * (int)kernel0[2]; - sum += (int)r1[0] * (int)kernel0[3]; - sum += (int)r1[1] * (int)kernel0[4]; - sum += (int)r1[2] * (int)kernel0[5]; - sum += (int)r2[0] * (int)kernel0[6]; - sum += (int)r2[1] * (int)kernel0[7]; - sum += (int)r2[2] * (int)kernel0[8]; + for (; remain>0; remain--) + { + int sum = 0; + + sum += (int)r0[0] * kernel[0]; + sum += (int)r0[1] * kernel[1]; + sum += (int)r0[2] * kernel[2]; + sum += (int)r1[0] * kernel[3]; + sum += (int)r1[1] * kernel[4]; + sum += (int)r1[2] * kernel[5]; + sum += (int)r2[0] * kernel[6]; + sum += (int)r2[1] * kernel[7]; + sum += (int)r2[2] * kernel[8]; *outptr = sum; diff --git a/src/layer/arm/dequantize_arm.cpp b/src/layer/arm/dequantize_arm.cpp index 0ccd1e332..d98f763eb 100644 --- a/src/layer/arm/dequantize_arm.cpp +++ b/src/layer/arm/dequantize_arm.cpp @@ -21,11 +21,6 @@ DEFINE_LAYER_CREATOR(Dequantize_arm) int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { -#if __aarch64__ - // TODO port to aarch64 - return Dequantize::forward_inplace(bottom_top_blob, opt); -#endif // __aarch64__ - int dims = bottom_top_blob.dims; if (dims == 1) @@ -116,7 +111,41 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con #if __ARM_NEON #if __aarch64__ - // TODO + float32x4_t _bias = vdupq_n_f32(bias); + float32x4_t _scale = vdupq_n_f32(scale); + + if (nn > 0) + { + asm volatile( + "dup v2.4s, %w6 \n" // scale + "dup v3.4s, %w7 \n" // bias + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s, v1.4s}, [%1], #32 \n" // data + // top_s32 -> top_f32 + "scvtf v5.4s, v0.4s \n" + "scvtf v6.4s, v1.4s \n" + // top_f32 = top_f32 * scale_out + "fmul v5.4s, v5.4s, v2.4s \n" + "fmul v6.4s, v6.4s, v2.4s \n" + // top_f32 = top_f32 + bias_tm + "fadd v5.4s, v5.4s, v3.4s \n" + "fadd v6.4s, v6.4s, v3.4s \n" + // save top_f32 + "st1 {v5.4s, v6.4s}, [%2], #32 \n" + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(intptr), // %1 + "=r"(ptr) // %2 + : "0"(nn), + "1"(intptr), + "2"(ptr), + "r"(_scale), // %6 + "r"(_bias) // %7 + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6" + ); + } #else if (nn > 0) { @@ -127,8 +156,8 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con "vdup.f32 q12, %7 \n" //q12 bias "0: \n" - "vcvt.f32.s32 q0, q0 \n" - "vcvt.f32.s32 q1, q1 \n" + "vcvt.f32.s32 q0, q0 \n" + "vcvt.f32.s32 q1, q1 \n" "vmul.f32 q0,q0,q10 \n" "vmul.f32 q1,q1,q10 \n" @@ -183,7 +212,35 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con #if __ARM_NEON #if __aarch64__ - // TODO + float32x4_t _scale = vdupq_n_f32(scale); + + if (nn > 0) + { + asm volatile( + "dup v2.4s, %w6 \n" // scale + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s, v1.4s}, [%1], #32 \n" // data + // top_s32 -> top_f32 + "scvtf v5.4s, v0.4s \n" + "scvtf v6.4s, v1.4s \n" + // top_f32 = top_f32 * scale_out + "fmul v5.4s, v5.4s, v2.4s \n" + "fmul v6.4s, v6.4s, v2.4s \n" + // save top_f32 + "st1 {v5.4s, v6.4s}, [%2], #32 \n" + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(intptr), // %1 + "=r"(ptr) // %2 + : "0"(nn), + "1"(intptr), + "2"(ptr), + "r"(_scale) // %6 + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6" + ); + } #else if (nn > 0) { @@ -193,8 +250,8 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con "vdup.f32 q10, %6 \n" //q10 scale "0: \n" - "vcvt.f32.s32 q0, q0 \n" - "vcvt.f32.s32 q1, q1 \n" + "vcvt.f32.s32 q0, q0 \n" + "vcvt.f32.s32 q1, q1 \n" "vmul.f32 q2,q0,q10 \n" "vmul.f32 q3,q1,q10 \n" diff --git a/src/layer/arm/quantize_arm.cpp b/src/layer/arm/quantize_arm.cpp index cacb5d08c..df8a829fd 100644 --- a/src/layer/arm/quantize_arm.cpp +++ b/src/layer/arm/quantize_arm.cpp @@ -31,11 +31,6 @@ static inline signed char float2int8(float v) int Quantize_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { -#if __aarch64__ - // TODO port to aarch64 fcvtas - return Quantize::forward(bottom_blob, top_blob, opt); -#endif // __aarch64__ - #if !__aarch64__ int FPSCR_value = 0; @@ -115,7 +110,40 @@ int Quantize_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o #if __ARM_NEON #if __aarch64__ - // TODO + float32x4_t _scale = vdupq_n_f32(scale); + + if (nn > 0) + { + asm volatile( + "dup v2.4s, %w6 \n" //scale + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s, v1.4s}, [%1], #32 \n" //data + // bottom_f32 = bottom_f32 * scale + "fmul v3.4s, v0.4s, v2.4s \n" + "fmul v4.4s, v1.4s, v2.4s \n" + // top_f32 -> top_s32 + "fcvtas v5.4s, v3.4s \n" + "fcvtas v6.4s, v4.4s \n" + // top_s32 -> top_s16 + "sqxtn v7.4h, v5.4s \n" + "sqxtn2 v7.8h, v6.4s \n" + // top_s16 -> top_s8 + "sqxtn v8.8b, v7.8h \n" + // save top_s8 + "st1 {v8.8b}, [%2], #8 \n" + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr), // %1 + "=r"(outptr) // %2 + : "0"(nn), + "1"(ptr), + "2"(outptr), + "r"(_scale) // %6 + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8" + ); + } #else if (nn > 0) {