From 68468dccbd842c716345c7d44fc7d634fe901c8a Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 16 Apr 2021 18:56:18 +0800 Subject: [PATCH] arm neon assembly optimization for padding int8 pack8, convolution int8 out elempack 4 --- ...int8.h => convolution_1x1_pack1to4_int8.h} | 8 +- ...int8.h => convolution_1x1_pack8to4_int8.h} | 8 +- ...int8.h => convolution_3x3_pack1to4_int8.h} | 8 +- ...int8.h => convolution_3x3_pack8to4_int8.h} | 1580 ++++++++++++----- ...int8.h => convolution_7x7_pack1to4_int8.h} | 4 +- src/layer/arm/convolution_arm.cpp | 141 +- ...to8_int8.h => convolution_pack1to4_int8.h} | 9 +- ...ck8_int8.h => convolution_pack8to4_int8.h} | 26 +- ...t8.h => convolution_sgemm_pack1to4_int8.h} | 8 +- ...t8.h => convolution_sgemm_pack8to4_int8.h} | 8 +- src/layer/arm/convolutiondepthwise_arm.cpp | 62 +- src/layer/arm/padding_pack8_int8.h | 358 +++- 12 files changed, 1521 insertions(+), 699 deletions(-) rename src/layer/arm/{convolution_1x1_pack1to8_int8.h => convolution_1x1_pack1to4_int8.h} (90%) rename src/layer/arm/{convolution_1x1_pack8_int8.h => convolution_1x1_pack8to4_int8.h} (86%) rename src/layer/arm/{convolution_3x3_pack1to8_int8.h => convolution_3x3_pack1to4_int8.h} (94%) rename src/layer/arm/{convolution_3x3_pack8_int8.h => convolution_3x3_pack8to4_int8.h} (53%) rename src/layer/arm/{convolution_7x7_pack1to8_int8.h => convolution_7x7_pack1to4_int8.h} (95%) rename src/layer/arm/{convolution_pack1to8_int8.h => convolution_pack1to4_int8.h} (89%) rename src/layer/arm/{convolution_pack8_int8.h => convolution_pack8to4_int8.h} (67%) rename src/layer/arm/{convolution_sgemm_pack1to8_int8.h => convolution_sgemm_pack1to4_int8.h} (99%) rename src/layer/arm/{convolution_sgemm_pack8_int8.h => convolution_sgemm_pack8to4_int8.h} (98%) diff --git a/src/layer/arm/convolution_1x1_pack1to8_int8.h b/src/layer/arm/convolution_1x1_pack1to4_int8.h similarity index 90% rename from src/layer/arm/convolution_1x1_pack1to8_int8.h rename to src/layer/arm/convolution_1x1_pack1to4_int8.h index 98db60333..049d3ed3d 100644 --- a/src/layer/arm/convolution_1x1_pack1to8_int8.h +++ b/src/layer/arm/convolution_1x1_pack1to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv1x1s1_sgemm_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv1x1s1_sgemm_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -22,10 +22,10 @@ static void conv1x1s1_sgemm_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_ bottom_im2col.w = size; bottom_im2col.h = 1; - im2col_sgemm_pack1to8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } -static void conv1x1s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv1x1s2_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int channels = bottom_blob.c; @@ -79,5 +79,5 @@ static void conv1x1s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, } } - conv1x1s1_sgemm_pack1to8_int8_neon(bottom_blob_shrinked, top_blob, kernel, opt); + conv1x1s1_sgemm_pack1to4_int8_neon(bottom_blob_shrinked, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolution_1x1_pack8_int8.h b/src/layer/arm/convolution_1x1_pack8to4_int8.h similarity index 86% rename from src/layer/arm/convolution_1x1_pack8_int8.h rename to src/layer/arm/convolution_1x1_pack8to4_int8.h index c273c5c7f..a99a9438f 100644 --- a/src/layer/arm/convolution_1x1_pack8_int8.h +++ b/src/layer/arm/convolution_1x1_pack8to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv1x1s1_sgemm_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv1x1s1_sgemm_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -22,10 +22,10 @@ static void conv1x1s1_sgemm_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blo bottom_im2col.w = size; bottom_im2col.h = 1; - im2col_sgemm_pack8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack8to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } -static void conv1x1s2_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv1x1s2_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int channels = bottom_blob.c; @@ -86,5 +86,5 @@ static void conv1x1s2_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, con } } - conv1x1s1_sgemm_pack8_int8_neon(bottom_blob_shrinked, top_blob, kernel, opt); + conv1x1s1_sgemm_pack8to4_int8_neon(bottom_blob_shrinked, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolution_3x3_pack1to8_int8.h b/src/layer/arm/convolution_3x3_pack1to4_int8.h similarity index 94% rename from src/layer/arm/convolution_3x3_pack1to8_int8.h rename to src/layer/arm/convolution_3x3_pack1to4_int8.h index fada4297c..8dbac7e44 100644 --- a/src/layer/arm/convolution_3x3_pack1to8_int8.h +++ b/src/layer/arm/convolution_3x3_pack1to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv3x3s1_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int inch = bottom_blob.c; @@ -76,10 +76,10 @@ static void conv3x3s1_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, } } - im2col_sgemm_pack1to8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } -static void conv3x3s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv3x3s2_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int inch = bottom_blob.c; @@ -143,5 +143,5 @@ static void conv3x3s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, } } - im2col_sgemm_pack1to8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolution_3x3_pack8_int8.h b/src/layer/arm/convolution_3x3_pack8to4_int8.h similarity index 53% rename from src/layer/arm/convolution_3x3_pack8_int8.h rename to src/layer/arm/convolution_3x3_pack8to4_int8.h index d93dea024..d234e08b3 100644 --- a/src/layer/arm/convolution_3x3_pack8_int8.h +++ b/src/layer/arm/convolution_3x3_pack8to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_winograd42_transform_kernel_pack8_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch) +static void conv3x3s1_winograd42_transform_kernel_pack8to4_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch) { // winograd42 transform kernel Mat kernel_tm(6 * 6, inch, outch, 2u); @@ -63,8 +63,8 @@ static void conv3x3s1_winograd42_transform_kernel_pack8_int8_neon(const Mat& ker // interleave // src = 36-inch-outch - // dst = 8b-8a-inch/8a-36-outch/8b - kernel_tm_pack8.create(inch / 8, 36, outch / 8, (size_t)2u * 64, 64); + // dst = 4b-8a-inch/8a-36-outch/4b + kernel_tm_pack8.create(inch / 8, 36, outch / 8 + (outch % 8) / 4, (size_t)2u * 64, 64); int q = 0; for (; q + 7 < outch; q += 8) @@ -78,11 +78,11 @@ static void conv3x3s1_winograd42_transform_kernel_pack8_int8_neon(const Mat& ker const Mat k6 = kernel_tm.channel(q + 6); const Mat k7 = kernel_tm.channel(q + 7); - Mat g0 = kernel_tm_pack8.channel(q / 8); + Mat kernel_tm = kernel_tm_pack8.channel(q / 8); for (int k = 0; k < 36; k++) { - short* g00 = g0.row(k); + short* g00 = kernel_tm.row(k); for (int p = 0; p + 7 < inch; p += 8) { @@ -111,9 +111,41 @@ static void conv3x3s1_winograd42_transform_kernel_pack8_int8_neon(const Mat& ker } } } + for (; q + 3 < outch; q += 4) + { + const Mat k0 = kernel_tm.channel(q); + const Mat k1 = kernel_tm.channel(q + 1); + const Mat k2 = kernel_tm.channel(q + 2); + const Mat k3 = kernel_tm.channel(q + 3); + + Mat kernel_tm = kernel_tm_pack8.channel(q / 8 + (q % 8) / 4); + + for (int k = 0; k < 36; k++) + { + short* g00 = kernel_tm.row(k); + + for (int p = 0; p + 7 < inch; p += 8) + { + for (int i = 0; i < 8; i++) + { + const short* k00 = k0.row(p + i); + const short* k10 = k1.row(p + i); + const short* k20 = k2.row(p + i); + const short* k30 = k3.row(p + i); + + g00[0] = k00[k]; + g00[1] = k10[k]; + g00[2] = k20[k]; + g00[3] = k30[k]; + + g00 += 4; + } + } + } + } } -static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) +static void conv3x3s1_winograd42_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -506,14 +538,22 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to bottom_blob_tm = Mat(); // permute end - top_blob_tm.create(tiles, 36, outch, 4u * elempack, elempack, opt.workspace_allocator); + top_blob_tm.create(tiles, 36, outch, 4u * 4, 4, opt.workspace_allocator); + + int nn_outch = 0; + int remain_outch_start = 0; + + nn_outch = outch >> 1; #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp = 0; pp < nn_outch; pp++) { + int p = pp * 2; + int* output0_tm = top_blob_tm.channel(p); + int* output1_tm = top_blob_tm.channel(p + 1); - const Mat kernel0_tm = kernel_tm.channel(p); + const Mat kernel0_tm = kernel_tm.channel(p / 2); for (int r = 0; r < 36; r++) { @@ -529,22 +569,22 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to int nn = inch; // inch always > 0 asm volatile( - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 + "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 "eor v8.16b, v8.16b, v8.16b \n" "eor v9.16b, v9.16b, v9.16b \n" - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 + "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 "eor v10.16b, v10.16b, v10.16b \n" "eor v11.16b, v11.16b, v11.16b \n" - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" "eor v12.16b, v12.16b, v12.16b \n" "eor v13.16b, v13.16b, v13.16b \n" - "prfm pldl1keep, [%3, #256] \n" + "prfm pldl1keep, [%4, #256] \n" "eor v14.16b, v14.16b, v14.16b \n" "eor v15.16b, v15.16b, v15.16b \n" @@ -568,284 +608,286 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to "0: \n" "smlal v8.4s, v4.4h, v0.h[0] \n" - "smlal2 v9.4s, v4.8h, v0.h[0] \n" - "smlal v10.4s, v4.4h, v0.h[1] \n" - "smlal2 v11.4s, v4.8h, v0.h[1] \n" - "smlal v12.4s, v4.4h, v0.h[2] \n" - "smlal2 v13.4s, v4.8h, v0.h[2] \n" - "smlal v14.4s, v4.4h, v0.h[3] \n" - "smlal2 v15.4s, v4.8h, v0.h[3] \n" - "smlal v16.4s, v4.4h, v0.h[4] \n" - "smlal2 v17.4s, v4.8h, v0.h[4] \n" - "smlal v18.4s, v4.4h, v0.h[5] \n" - "smlal2 v19.4s, v4.8h, v0.h[5] \n" - "smlal v20.4s, v4.4h, v0.h[6] \n" - "smlal2 v21.4s, v4.8h, v0.h[6] \n" - "smlal v22.4s, v4.4h, v0.h[7] \n" - "smlal2 v23.4s, v4.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r23 + "smlal2 v20.4s, v4.8h, v0.h[0] \n" + "smlal v9.4s, v4.4h, v0.h[1] \n" + "smlal2 v21.4s, v4.8h, v0.h[1] \n" + "smlal v10.4s, v4.4h, v0.h[2] \n" + "smlal2 v22.4s, v4.8h, v0.h[2] \n" + "smlal v11.4s, v4.4h, v0.h[3] \n" + "smlal2 v23.4s, v4.8h, v0.h[3] \n" + "smlal v12.4s, v4.4h, v0.h[4] \n" + "smlal2 v24.4s, v4.8h, v0.h[4] \n" + "smlal v13.4s, v4.4h, v0.h[5] \n" + "smlal2 v25.4s, v4.8h, v0.h[5] \n" + "smlal v14.4s, v4.4h, v0.h[6] \n" + "smlal2 v26.4s, v4.8h, v0.h[6] \n" + "smlal v15.4s, v4.4h, v0.h[7] \n" + "smlal2 v27.4s, v4.8h, v0.h[7] \n" + + "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r23 + + "smlal v16.4s, v4.4h, v1.h[0] \n" + "smlal2 v28.4s, v4.8h, v1.h[0] \n" + "smlal v17.4s, v4.4h, v1.h[1] \n" + "smlal2 v29.4s, v4.8h, v1.h[1] \n" - "smlal v24.4s, v4.4h, v1.h[0] \n" - "smlal2 v25.4s, v4.8h, v1.h[0] \n" - "smlal v26.4s, v4.4h, v1.h[1] \n" - "smlal2 v27.4s, v4.8h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v28.4s, v4.4h, v1.h[2] \n" - "smlal2 v29.4s, v4.8h, v1.h[2] \n" - "smlal v30.4s, v4.4h, v1.h[3] \n" + "smlal v18.4s, v4.4h, v1.h[2] \n" + "smlal2 v30.4s, v4.8h, v1.h[2] \n" + "smlal v19.4s, v4.4h, v1.h[3] \n" "smlal2 v31.4s, v4.8h, v1.h[3] \n" - "ld1 {v6.8h, v7.8h}, [%3], #32 \n" // w23 + "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w23 "smlal v8.4s, v5.4h, v1.h[4] \n" - "smlal2 v9.4s, v5.8h, v1.h[4] \n" - "smlal v10.4s, v5.4h, v1.h[5] \n" - "smlal2 v11.4s, v5.8h, v1.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v12.4s, v5.4h, v1.h[6] \n" - "smlal2 v13.4s, v5.8h, v1.h[6] \n" - "smlal v14.4s, v5.4h, v1.h[7] \n" - "smlal2 v15.4s, v5.8h, v1.h[7] \n" - "smlal v16.4s, v5.4h, v2.h[0] \n" - "smlal2 v17.4s, v5.8h, v2.h[0] \n" - "smlal v18.4s, v5.4h, v2.h[1] \n" - "smlal2 v19.4s, v5.8h, v2.h[1] \n" - "smlal v20.4s, v5.4h, v2.h[2] \n" - "smlal2 v21.4s, v5.8h, v2.h[2] \n" - "smlal v22.4s, v5.4h, v2.h[3] \n" - "smlal2 v23.4s, v5.8h, v2.h[3] \n" - "smlal v24.4s, v5.4h, v2.h[4] \n" - "smlal2 v25.4s, v5.8h, v2.h[4] \n" - "smlal v26.4s, v5.4h, v2.h[5] \n" - "smlal2 v27.4s, v5.8h, v2.h[5] \n" - "smlal v28.4s, v5.4h, v2.h[6] \n" - "smlal2 v29.4s, v5.8h, v2.h[6] \n" - "smlal v30.4s, v5.4h, v2.h[7] \n" + "smlal2 v20.4s, v5.8h, v1.h[4] \n" + "smlal v9.4s, v5.4h, v1.h[5] \n" + "smlal2 v21.4s, v5.8h, v1.h[5] \n" + + "prfm pldl1keep, [%4, #256] \n" + + "smlal v10.4s, v5.4h, v1.h[6] \n" + "smlal2 v22.4s, v5.8h, v1.h[6] \n" + "smlal v11.4s, v5.4h, v1.h[7] \n" + "smlal2 v23.4s, v5.8h, v1.h[7] \n" + "smlal v12.4s, v5.4h, v2.h[0] \n" + "smlal2 v24.4s, v5.8h, v2.h[0] \n" + "smlal v13.4s, v5.4h, v2.h[1] \n" + "smlal2 v25.4s, v5.8h, v2.h[1] \n" + "smlal v14.4s, v5.4h, v2.h[2] \n" + "smlal2 v26.4s, v5.8h, v2.h[2] \n" + "smlal v15.4s, v5.4h, v2.h[3] \n" + "smlal2 v27.4s, v5.8h, v2.h[3] \n" + "smlal v16.4s, v5.4h, v2.h[4] \n" + "smlal2 v28.4s, v5.8h, v2.h[4] \n" + "smlal v17.4s, v5.4h, v2.h[5] \n" + "smlal2 v29.4s, v5.8h, v2.h[5] \n" + "smlal v18.4s, v5.4h, v2.h[6] \n" + "smlal2 v30.4s, v5.8h, v2.h[6] \n" + "smlal v19.4s, v5.4h, v2.h[7] \n" "smlal2 v31.4s, v5.8h, v2.h[7] \n" - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r45 + "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r45 "smlal v8.4s, v6.4h, v3.h[0] \n" - "smlal2 v9.4s, v6.8h, v3.h[0] \n" - "smlal v10.4s, v6.4h, v3.h[1] \n" - "smlal2 v11.4s, v6.8h, v3.h[1] \n" + "smlal2 v20.4s, v6.8h, v3.h[0] \n" + "smlal v9.4s, v6.4h, v3.h[1] \n" + "smlal2 v21.4s, v6.8h, v3.h[1] \n" - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v12.4s, v6.4h, v3.h[2] \n" - "smlal2 v13.4s, v6.8h, v3.h[2] \n" - "smlal v14.4s, v6.4h, v3.h[3] \n" - "smlal2 v15.4s, v6.8h, v3.h[3] \n" - "smlal v16.4s, v6.4h, v3.h[4] \n" - "smlal2 v17.4s, v6.8h, v3.h[4] \n" - "smlal v18.4s, v6.4h, v3.h[5] \n" - "smlal2 v19.4s, v6.8h, v3.h[5] \n" - "smlal v20.4s, v6.4h, v3.h[6] \n" - "smlal2 v21.4s, v6.8h, v3.h[6] \n" - "smlal v22.4s, v6.4h, v3.h[7] \n" - "smlal2 v23.4s, v6.8h, v3.h[7] \n" - - "smlal v24.4s, v6.4h, v0.h[0] \n" - "smlal2 v25.4s, v6.8h, v0.h[0] \n" - "smlal v26.4s, v6.4h, v0.h[1] \n" - "smlal2 v27.4s, v6.8h, v0.h[1] \n" - "smlal v28.4s, v6.4h, v0.h[2] \n" - "smlal2 v29.4s, v6.8h, v0.h[2] \n" - "smlal v30.4s, v6.4h, v0.h[3] \n" + "smlal v10.4s, v6.4h, v3.h[2] \n" + "smlal2 v22.4s, v6.8h, v3.h[2] \n" + "smlal v11.4s, v6.4h, v3.h[3] \n" + "smlal2 v23.4s, v6.8h, v3.h[3] \n" + "smlal v12.4s, v6.4h, v3.h[4] \n" + "smlal2 v24.4s, v6.8h, v3.h[4] \n" + "smlal v13.4s, v6.4h, v3.h[5] \n" + "smlal2 v25.4s, v6.8h, v3.h[5] \n" + "smlal v14.4s, v6.4h, v3.h[6] \n" + "smlal2 v26.4s, v6.8h, v3.h[6] \n" + "smlal v15.4s, v6.4h, v3.h[7] \n" + "smlal2 v27.4s, v6.8h, v3.h[7] \n" + + "smlal v16.4s, v6.4h, v0.h[0] \n" + "smlal2 v28.4s, v6.8h, v0.h[0] \n" + "smlal v17.4s, v6.4h, v0.h[1] \n" + "smlal2 v29.4s, v6.8h, v0.h[1] \n" + "smlal v18.4s, v6.4h, v0.h[2] \n" + "smlal2 v30.4s, v6.8h, v0.h[2] \n" + "smlal v19.4s, v6.4h, v0.h[3] \n" "smlal2 v31.4s, v6.8h, v0.h[3] \n" - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w45 + "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w45 "smlal v8.4s, v7.4h, v0.h[4] \n" - "smlal2 v9.4s, v7.8h, v0.h[4] \n" - "smlal v10.4s, v7.4h, v0.h[5] \n" - "smlal2 v11.4s, v7.8h, v0.h[5] \n" + "smlal2 v20.4s, v7.8h, v0.h[4] \n" + "smlal v9.4s, v7.4h, v0.h[5] \n" + "smlal2 v21.4s, v7.8h, v0.h[5] \n" - "prfm pldl1keep, [%3, #256] \n" + "prfm pldl1keep, [%4, #256] \n" - "smlal v12.4s, v7.4h, v0.h[6] \n" - "smlal2 v13.4s, v7.8h, v0.h[6] \n" - "smlal v14.4s, v7.4h, v0.h[7] \n" - "smlal2 v15.4s, v7.8h, v0.h[7] \n" + "smlal v10.4s, v7.4h, v0.h[6] \n" + "smlal2 v22.4s, v7.8h, v0.h[6] \n" + "smlal v11.4s, v7.4h, v0.h[7] \n" + "smlal2 v23.4s, v7.8h, v0.h[7] \n" - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r67 + "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r67 - "smlal v16.4s, v7.4h, v1.h[0] \n" - "smlal2 v17.4s, v7.8h, v1.h[0] \n" - "smlal v18.4s, v7.4h, v1.h[1] \n" - "smlal2 v19.4s, v7.8h, v1.h[1] \n" + "smlal v12.4s, v7.4h, v1.h[0] \n" + "smlal2 v24.4s, v7.8h, v1.h[0] \n" + "smlal v13.4s, v7.4h, v1.h[1] \n" + "smlal2 v25.4s, v7.8h, v1.h[1] \n" - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v20.4s, v7.4h, v1.h[2] \n" - "smlal2 v21.4s, v7.8h, v1.h[2] \n" - "smlal v22.4s, v7.4h, v1.h[3] \n" - "smlal2 v23.4s, v7.8h, v1.h[3] \n" - "smlal v24.4s, v7.4h, v1.h[4] \n" - "smlal2 v25.4s, v7.8h, v1.h[4] \n" - "smlal v26.4s, v7.4h, v1.h[5] \n" - "smlal2 v27.4s, v7.8h, v1.h[5] \n" - "smlal v28.4s, v7.4h, v1.h[6] \n" - "smlal2 v29.4s, v7.8h, v1.h[6] \n" - "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal v14.4s, v7.4h, v1.h[2] \n" + "smlal2 v26.4s, v7.8h, v1.h[2] \n" + "smlal v15.4s, v7.4h, v1.h[3] \n" + "smlal2 v27.4s, v7.8h, v1.h[3] \n" + "smlal v16.4s, v7.4h, v1.h[4] \n" + "smlal2 v28.4s, v7.8h, v1.h[4] \n" + "smlal v17.4s, v7.4h, v1.h[5] \n" + "smlal2 v29.4s, v7.8h, v1.h[5] \n" + "smlal v18.4s, v7.4h, v1.h[6] \n" + "smlal2 v30.4s, v7.8h, v1.h[6] \n" + "smlal v19.4s, v7.4h, v1.h[7] \n" "smlal2 v31.4s, v7.8h, v1.h[7] \n" "smlal v8.4s, v4.4h, v2.h[0] \n" - "smlal2 v9.4s, v4.8h, v2.h[0] \n" - "smlal v10.4s, v4.4h, v2.h[1] \n" - "smlal2 v11.4s, v4.8h, v2.h[1] \n" - "smlal v12.4s, v4.4h, v2.h[2] \n" - "smlal2 v13.4s, v4.8h, v2.h[2] \n" - "smlal v14.4s, v4.4h, v2.h[3] \n" - "smlal2 v15.4s, v4.8h, v2.h[3] \n" - "smlal v16.4s, v4.4h, v2.h[4] \n" - "smlal2 v17.4s, v4.8h, v2.h[4] \n" - "smlal v18.4s, v4.4h, v2.h[5] \n" - "smlal2 v19.4s, v4.8h, v2.h[5] \n" - "smlal v20.4s, v4.4h, v2.h[6] \n" - "smlal2 v21.4s, v4.8h, v2.h[6] \n" - "smlal v22.4s, v4.4h, v2.h[7] \n" - "smlal2 v23.4s, v4.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r89 + "smlal2 v20.4s, v4.8h, v2.h[0] \n" + "smlal v9.4s, v4.4h, v2.h[1] \n" + "smlal2 v21.4s, v4.8h, v2.h[1] \n" + "smlal v10.4s, v4.4h, v2.h[2] \n" + "smlal2 v22.4s, v4.8h, v2.h[2] \n" + "smlal v11.4s, v4.4h, v2.h[3] \n" + "smlal2 v23.4s, v4.8h, v2.h[3] \n" + "smlal v12.4s, v4.4h, v2.h[4] \n" + "smlal2 v24.4s, v4.8h, v2.h[4] \n" + "smlal v13.4s, v4.4h, v2.h[5] \n" + "smlal2 v25.4s, v4.8h, v2.h[5] \n" + "smlal v14.4s, v4.4h, v2.h[6] \n" + "smlal2 v26.4s, v4.8h, v2.h[6] \n" + "smlal v15.4s, v4.4h, v2.h[7] \n" + "smlal2 v27.4s, v4.8h, v2.h[7] \n" + + "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r89 + + "smlal v16.4s, v4.4h, v3.h[0] \n" + "smlal2 v28.4s, v4.8h, v3.h[0] \n" + "smlal v17.4s, v4.4h, v3.h[1] \n" + "smlal2 v29.4s, v4.8h, v3.h[1] \n" - "smlal v24.4s, v4.4h, v3.h[0] \n" - "smlal2 v25.4s, v4.8h, v3.h[0] \n" - "smlal v26.4s, v4.4h, v3.h[1] \n" - "smlal2 v27.4s, v4.8h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v28.4s, v4.4h, v3.h[2] \n" - "smlal2 v29.4s, v4.8h, v3.h[2] \n" - "smlal v30.4s, v4.4h, v3.h[3] \n" + "smlal v18.4s, v4.4h, v3.h[2] \n" + "smlal2 v30.4s, v4.8h, v3.h[2] \n" + "smlal v19.4s, v4.4h, v3.h[3] \n" "smlal2 v31.4s, v4.8h, v3.h[3] \n" - "ld1 {v6.8h, v7.8h}, [%3], #32 \n" // w67 + "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w67 "smlal v8.4s, v5.4h, v3.h[4] \n" - "smlal2 v9.4s, v5.8h, v3.h[4] \n" - "smlal v10.4s, v5.4h, v3.h[5] \n" - "smlal2 v11.4s, v5.8h, v3.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v12.4s, v5.4h, v3.h[6] \n" - "smlal2 v13.4s, v5.8h, v3.h[6] \n" - "smlal v14.4s, v5.4h, v3.h[7] \n" - "smlal2 v15.4s, v5.8h, v3.h[7] \n" - - "smlal v16.4s, v5.4h, v0.h[0] \n" - "smlal2 v17.4s, v5.8h, v0.h[0] \n" - "smlal v18.4s, v5.4h, v0.h[1] \n" - "smlal2 v19.4s, v5.8h, v0.h[1] \n" - "smlal v20.4s, v5.4h, v0.h[2] \n" - "smlal2 v21.4s, v5.8h, v0.h[2] \n" - "smlal v22.4s, v5.4h, v0.h[3] \n" - "smlal2 v23.4s, v5.8h, v0.h[3] \n" - "smlal v24.4s, v5.4h, v0.h[4] \n" - "smlal2 v25.4s, v5.8h, v0.h[4] \n" - "smlal v26.4s, v5.4h, v0.h[5] \n" - "smlal2 v27.4s, v5.8h, v0.h[5] \n" - "smlal v28.4s, v5.4h, v0.h[6] \n" - "smlal2 v29.4s, v5.8h, v0.h[6] \n" - "smlal v30.4s, v5.4h, v0.h[7] \n" + "smlal2 v20.4s, v5.8h, v3.h[4] \n" + "smlal v9.4s, v5.4h, v3.h[5] \n" + "smlal2 v21.4s, v5.8h, v3.h[5] \n" + + "prfm pldl1keep, [%4, #256] \n" + + "smlal v10.4s, v5.4h, v3.h[6] \n" + "smlal2 v22.4s, v5.8h, v3.h[6] \n" + "smlal v11.4s, v5.4h, v3.h[7] \n" + "smlal2 v23.4s, v5.8h, v3.h[7] \n" + + "smlal v12.4s, v5.4h, v0.h[0] \n" + "smlal2 v24.4s, v5.8h, v0.h[0] \n" + "smlal v13.4s, v5.4h, v0.h[1] \n" + "smlal2 v25.4s, v5.8h, v0.h[1] \n" + "smlal v14.4s, v5.4h, v0.h[2] \n" + "smlal2 v26.4s, v5.8h, v0.h[2] \n" + "smlal v15.4s, v5.4h, v0.h[3] \n" + "smlal2 v27.4s, v5.8h, v0.h[3] \n" + "smlal v16.4s, v5.4h, v0.h[4] \n" + "smlal2 v28.4s, v5.8h, v0.h[4] \n" + "smlal v17.4s, v5.4h, v0.h[5] \n" + "smlal2 v29.4s, v5.8h, v0.h[5] \n" + "smlal v18.4s, v5.4h, v0.h[6] \n" + "smlal2 v30.4s, v5.8h, v0.h[6] \n" + "smlal v19.4s, v5.4h, v0.h[7] \n" "smlal2 v31.4s, v5.8h, v0.h[7] \n" - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r1011 + "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r1011 "smlal v8.4s, v6.4h, v1.h[0] \n" - "smlal2 v9.4s, v6.8h, v1.h[0] \n" - "smlal v10.4s, v6.4h, v1.h[1] \n" - "smlal2 v11.4s, v6.8h, v1.h[1] \n" + "smlal2 v20.4s, v6.8h, v1.h[0] \n" + "smlal v9.4s, v6.4h, v1.h[1] \n" + "smlal2 v21.4s, v6.8h, v1.h[1] \n" - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v12.4s, v6.4h, v1.h[2] \n" - "smlal2 v13.4s, v6.8h, v1.h[2] \n" - "smlal v14.4s, v6.4h, v1.h[3] \n" - "smlal2 v15.4s, v6.8h, v1.h[3] \n" - "smlal v16.4s, v6.4h, v1.h[4] \n" - "smlal2 v17.4s, v6.8h, v1.h[4] \n" - "smlal v18.4s, v6.4h, v1.h[5] \n" - "smlal2 v19.4s, v6.8h, v1.h[5] \n" - "smlal v20.4s, v6.4h, v1.h[6] \n" - "smlal2 v21.4s, v6.8h, v1.h[6] \n" - "smlal v22.4s, v6.4h, v1.h[7] \n" - "smlal2 v23.4s, v6.8h, v1.h[7] \n" - "smlal v24.4s, v6.4h, v2.h[0] \n" - "smlal2 v25.4s, v6.8h, v2.h[0] \n" - "smlal v26.4s, v6.4h, v2.h[1] \n" - "smlal2 v27.4s, v6.8h, v2.h[1] \n" - "smlal v28.4s, v6.4h, v2.h[2] \n" - "smlal2 v29.4s, v6.8h, v2.h[2] \n" - "smlal v30.4s, v6.4h, v2.h[3] \n" + "smlal v10.4s, v6.4h, v1.h[2] \n" + "smlal2 v22.4s, v6.8h, v1.h[2] \n" + "smlal v11.4s, v6.4h, v1.h[3] \n" + "smlal2 v23.4s, v6.8h, v1.h[3] \n" + "smlal v12.4s, v6.4h, v1.h[4] \n" + "smlal2 v24.4s, v6.8h, v1.h[4] \n" + "smlal v13.4s, v6.4h, v1.h[5] \n" + "smlal2 v25.4s, v6.8h, v1.h[5] \n" + "smlal v14.4s, v6.4h, v1.h[6] \n" + "smlal2 v26.4s, v6.8h, v1.h[6] \n" + "smlal v15.4s, v6.4h, v1.h[7] \n" + "smlal2 v27.4s, v6.8h, v1.h[7] \n" + "smlal v16.4s, v6.4h, v2.h[0] \n" + "smlal2 v28.4s, v6.8h, v2.h[0] \n" + "smlal v17.4s, v6.4h, v2.h[1] \n" + "smlal2 v29.4s, v6.8h, v2.h[1] \n" + "smlal v18.4s, v6.4h, v2.h[2] \n" + "smlal2 v30.4s, v6.8h, v2.h[2] \n" + "smlal v19.4s, v6.4h, v2.h[3] \n" "smlal2 v31.4s, v6.8h, v2.h[3] \n" - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 + "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 "smlal v8.4s, v7.4h, v2.h[4] \n" - "smlal2 v9.4s, v7.8h, v2.h[4] \n" - "smlal v10.4s, v7.4h, v2.h[5] \n" - "smlal2 v11.4s, v7.8h, v2.h[5] \n" + "smlal2 v20.4s, v7.8h, v2.h[4] \n" + "smlal v9.4s, v7.4h, v2.h[5] \n" + "smlal2 v21.4s, v7.8h, v2.h[5] \n" - "prfm pldl1keep, [%3, #256] \n" + "prfm pldl1keep, [%4, #256] \n" - "smlal v12.4s, v7.4h, v2.h[6] \n" - "smlal2 v13.4s, v7.8h, v2.h[6] \n" - "smlal v14.4s, v7.4h, v2.h[7] \n" - "smlal2 v15.4s, v7.8h, v2.h[7] \n" + "smlal v10.4s, v7.4h, v2.h[6] \n" + "smlal2 v22.4s, v7.8h, v2.h[6] \n" + "smlal v11.4s, v7.4h, v2.h[7] \n" + "smlal2 v23.4s, v7.8h, v2.h[7] \n" - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 + "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 - "smlal v16.4s, v7.4h, v3.h[0] \n" - "smlal2 v17.4s, v7.8h, v3.h[0] \n" - "smlal v18.4s, v7.4h, v3.h[1] \n" - "smlal2 v19.4s, v7.8h, v3.h[1] \n" + "smlal v12.4s, v7.4h, v3.h[0] \n" + "smlal2 v24.4s, v7.8h, v3.h[0] \n" + "smlal v13.4s, v7.4h, v3.h[1] \n" + "smlal2 v25.4s, v7.8h, v3.h[1] \n" - "prfm pldl1keep, [%2, #256] \n" + "prfm pldl1keep, [%3, #256] \n" - "smlal v20.4s, v7.4h, v3.h[2] \n" - "smlal2 v21.4s, v7.8h, v3.h[2] \n" - "smlal v22.4s, v7.4h, v3.h[3] \n" - "smlal2 v23.4s, v7.8h, v3.h[3] \n" - "smlal v24.4s, v7.4h, v3.h[4] \n" - "smlal2 v25.4s, v7.8h, v3.h[4] \n" - "smlal v26.4s, v7.4h, v3.h[5] \n" - "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "smlal v14.4s, v7.4h, v3.h[2] \n" + "smlal2 v26.4s, v7.8h, v3.h[2] \n" + "smlal v15.4s, v7.4h, v3.h[3] \n" + "smlal2 v27.4s, v7.8h, v3.h[3] \n" + "smlal v16.4s, v7.4h, v3.h[4] \n" + "smlal2 v28.4s, v7.8h, v3.h[4] \n" + "smlal v17.4s, v7.4h, v3.h[5] \n" + "smlal2 v29.4s, v7.8h, v3.h[5] \n" "subs %w0, %w0, #1 \n" - "smlal v28.4s, v7.4h, v3.h[6] \n" - "smlal2 v29.4s, v7.8h, v3.h[6] \n" - "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal v18.4s, v7.4h, v3.h[6] \n" + "smlal2 v30.4s, v7.8h, v3.h[6] \n" + "smlal v19.4s, v7.4h, v3.h[7] \n" "smlal2 v31.4s, v7.8h, v3.h[7] \n" "bne 0b \n" - "sub %2, %2, #32 \n" "sub %3, %3, #32 \n" + "sub %4, %4, #32 \n" "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n" "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n" "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%1], #64 \n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%1], #64 \n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%1], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%2], #64 \n" : "=r"(nn), // %0 "=r"(output0_tm), // %1 - "=r"(r0), // %2 - "=r"(k0) // %3 + "=r"(output1_tm), // %2 + "=r"(r0), // %3 + "=r"(k0) // %4 : "0"(nn), "1"(output0_tm), - "2"(r0), - "3"(k0) + "2"(output1_tm), + "3"(r0), + "4"(k0) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); } for (; i + 7 < tiles; i += 8) @@ -1040,22 +1082,23 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to } vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - vst1q_s32(output0_tm + 8, _sum2); - vst1q_s32(output0_tm + 12, _sum3); - vst1q_s32(output0_tm + 16, _sum4); - vst1q_s32(output0_tm + 20, _sum5); - vst1q_s32(output0_tm + 24, _sum6); - vst1q_s32(output0_tm + 28, _sum7); - vst1q_s32(output0_tm + 32, _sum8); - vst1q_s32(output0_tm + 36, _sum9); - vst1q_s32(output0_tm + 40, _suma); - vst1q_s32(output0_tm + 44, _sumb); - vst1q_s32(output0_tm + 48, _sumc); - vst1q_s32(output0_tm + 52, _sumd); - vst1q_s32(output0_tm + 56, _sume); - vst1q_s32(output0_tm + 60, _sumf); - output0_tm += 64; + vst1q_s32(output1_tm, _sum1); + vst1q_s32(output0_tm + 4, _sum2); + vst1q_s32(output1_tm + 4, _sum3); + vst1q_s32(output0_tm + 8, _sum4); + vst1q_s32(output1_tm + 8, _sum5); + vst1q_s32(output0_tm + 12, _sum6); + vst1q_s32(output1_tm + 12, _sum7); + vst1q_s32(output0_tm + 16, _sum8); + vst1q_s32(output1_tm + 16, _sum9); + vst1q_s32(output0_tm + 20, _suma); + vst1q_s32(output1_tm + 20, _sumb); + vst1q_s32(output0_tm + 24, _sumc); + vst1q_s32(output1_tm + 24, _sumd); + vst1q_s32(output0_tm + 28, _sume); + vst1q_s32(output1_tm + 28, _sumf); + output0_tm += 32; + output1_tm += 32; } #endif // __aarch64__ for (; i + 3 < tiles; i += 4) @@ -1179,14 +1222,15 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to } vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - vst1q_s32(output0_tm + 8, _sum2); - vst1q_s32(output0_tm + 12, _sum3); - vst1q_s32(output0_tm + 16, _sum4); - vst1q_s32(output0_tm + 20, _sum5); - vst1q_s32(output0_tm + 24, _sum6); - vst1q_s32(output0_tm + 28, _sum7); - output0_tm += 32; + vst1q_s32(output1_tm, _sum1); + vst1q_s32(output0_tm + 4, _sum2); + vst1q_s32(output1_tm + 4, _sum3); + vst1q_s32(output0_tm + 8, _sum4); + vst1q_s32(output1_tm + 8, _sum5); + vst1q_s32(output0_tm + 12, _sum6); + vst1q_s32(output1_tm + 12, _sum7); + output0_tm += 16; + output1_tm += 16; #else asm volatile( "veor q8, q8 \n" @@ -1200,119 +1244,120 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to "0: \n" - "pld [%1, #256] \n" - "pld [%1, #512] \n" - "vldm %1!, {d0-d7} \n" + "pld [%3, #256] \n" + "pld [%3, #512] \n" + "vldm %3!, {d0-d7} \n" - "pld [%2, #256] \n" - "vld1.s16 {d8-d11}, [%2 :128]! \n" + "pld [%4, #256] \n" + "vld1.s16 {d8-d11}, [%4 :128]! \n" "vmlal.s16 q8, d8, d0[0] \n" - "vmlal.s16 q9, d9, d0[0] \n" - "vmlal.s16 q10, d8, d2[0] \n" - "vmlal.s16 q11, d9, d2[0] \n" - "vmlal.s16 q12, d8, d4[0] \n" - "vmlal.s16 q13, d9, d4[0] \n" - "vmlal.s16 q14, d8, d6[0] \n" + "vmlal.s16 q12, d9, d0[0] \n" + "vmlal.s16 q9, d8, d2[0] \n" + "vmlal.s16 q13, d9, d2[0] \n" + "vmlal.s16 q10, d8, d4[0] \n" + "vmlal.s16 q14, d9, d4[0] \n" + "vmlal.s16 q11, d8, d6[0] \n" "vmlal.s16 q15, d9, d6[0] \n" - "pld [%2, #128] \n" - "vld1.s16 {d8-d9}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d8-d9}, [%4 :128]! \n" "vmlal.s16 q8, d10, d0[1] \n" - "vmlal.s16 q9, d11, d0[1] \n" - "vmlal.s16 q10, d10, d2[1] \n" - "vmlal.s16 q11, d11, d2[1] \n" - "vmlal.s16 q12, d10, d4[1] \n" - "vmlal.s16 q13, d11, d4[1] \n" - "vmlal.s16 q14, d10, d6[1] \n" + "vmlal.s16 q12, d11, d0[1] \n" + "vmlal.s16 q9, d10, d2[1] \n" + "vmlal.s16 q13, d11, d2[1] \n" + "vmlal.s16 q10, d10, d4[1] \n" + "vmlal.s16 q14, d11, d4[1] \n" + "vmlal.s16 q11, d10, d6[1] \n" "vmlal.s16 q15, d11, d6[1] \n" - "pld [%2, #128] \n" - "vld1.s16 {d10-d11}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d10-d11}, [%4 :128]! \n" "vmlal.s16 q8, d8, d0[2] \n" - "vmlal.s16 q9, d9, d0[2] \n" - "vmlal.s16 q10, d8, d2[2] \n" - "vmlal.s16 q11, d9, d2[2] \n" - "vmlal.s16 q12, d8, d4[2] \n" - "vmlal.s16 q13, d9, d4[2] \n" - "vmlal.s16 q14, d8, d6[2] \n" + "vmlal.s16 q12, d9, d0[2] \n" + "vmlal.s16 q9, d8, d2[2] \n" + "vmlal.s16 q13, d9, d2[2] \n" + "vmlal.s16 q10, d8, d4[2] \n" + "vmlal.s16 q14, d9, d4[2] \n" + "vmlal.s16 q11, d8, d6[2] \n" "vmlal.s16 q15, d9, d6[2] \n" - "pld [%2, #128] \n" - "vld1.s16 {d8-d9}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d8-d9}, [%4 :128]! \n" "vmlal.s16 q8, d10, d0[3] \n" - "vmlal.s16 q9, d11, d0[3] \n" - "vmlal.s16 q10, d10, d2[3] \n" - "vmlal.s16 q11, d11, d2[3] \n" - "vmlal.s16 q12, d10, d4[3] \n" - "vmlal.s16 q13, d11, d4[3] \n" - "vmlal.s16 q14, d10, d6[3] \n" + "vmlal.s16 q12, d11, d0[3] \n" + "vmlal.s16 q9, d10, d2[3] \n" + "vmlal.s16 q13, d11, d2[3] \n" + "vmlal.s16 q10, d10, d4[3] \n" + "vmlal.s16 q14, d11, d4[3] \n" + "vmlal.s16 q11, d10, d6[3] \n" "vmlal.s16 q15, d11, d6[3] \n" - "pld [%2, #128] \n" - "vld1.s16 {d10-d11}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d10-d11}, [%4 :128]! \n" "vmlal.s16 q8, d8, d1[0] \n" - "vmlal.s16 q9, d9, d1[0] \n" - "vmlal.s16 q10, d8, d3[0] \n" - "vmlal.s16 q11, d9, d3[0] \n" - "vmlal.s16 q12, d8, d5[0] \n" - "vmlal.s16 q13, d9, d5[0] \n" - "vmlal.s16 q14, d8, d7[0] \n" + "vmlal.s16 q12, d9, d1[0] \n" + "vmlal.s16 q9, d8, d3[0] \n" + "vmlal.s16 q13, d9, d3[0] \n" + "vmlal.s16 q10, d8, d5[0] \n" + "vmlal.s16 q14, d9, d5[0] \n" + "vmlal.s16 q11, d8, d7[0] \n" "vmlal.s16 q15, d9, d7[0] \n" - "pld [%2, #128] \n" - "vld1.s16 {d8-d9}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d8-d9}, [%4 :128]! \n" "vmlal.s16 q8, d10, d1[1] \n" - "vmlal.s16 q9, d11, d1[1] \n" - "vmlal.s16 q10, d10, d3[1] \n" - "vmlal.s16 q11, d11, d3[1] \n" - "vmlal.s16 q12, d10, d5[1] \n" - "vmlal.s16 q13, d11, d5[1] \n" - "vmlal.s16 q14, d10, d7[1] \n" + "vmlal.s16 q12, d11, d1[1] \n" + "vmlal.s16 q9, d10, d3[1] \n" + "vmlal.s16 q13, d11, d3[1] \n" + "vmlal.s16 q10, d10, d5[1] \n" + "vmlal.s16 q14, d11, d5[1] \n" + "vmlal.s16 q11, d10, d7[1] \n" "vmlal.s16 q15, d11, d7[1] \n" - "pld [%2, #128] \n" - "vld1.s16 {d10-d11}, [%2 :128]! \n" + "pld [%4, #128] \n" + "vld1.s16 {d10-d11}, [%4 :128]! \n" "vmlal.s16 q8, d8, d1[2] \n" - "vmlal.s16 q9, d9, d1[2] \n" - "vmlal.s16 q10, d8, d3[2] \n" - "vmlal.s16 q11, d9, d3[2] \n" - "vmlal.s16 q12, d8, d5[2] \n" - "vmlal.s16 q13, d9, d5[2] \n" - "vmlal.s16 q14, d8, d7[2] \n" + "vmlal.s16 q12, d9, d1[2] \n" + "vmlal.s16 q9, d8, d3[2] \n" + "vmlal.s16 q13, d9, d3[2] \n" + "vmlal.s16 q10, d8, d5[2] \n" + "vmlal.s16 q14, d9, d5[2] \n" + "vmlal.s16 q11, d8, d7[2] \n" "vmlal.s16 q15, d9, d7[2] \n" - "subs %3, %3, #1 \n" + "subs %0, %0, #1 \n" "vmlal.s16 q8, d10, d1[3] \n" - "vmlal.s16 q9, d11, d1[3] \n" - "vmlal.s16 q10, d10, d3[3] \n" - "vmlal.s16 q11, d11, d3[3] \n" - "vmlal.s16 q12, d10, d5[3] \n" - "vmlal.s16 q13, d11, d5[3] \n" - "vmlal.s16 q14, d10, d7[3] \n" + "vmlal.s16 q12, d11, d1[3] \n" + "vmlal.s16 q9, d10, d3[3] \n" + "vmlal.s16 q13, d11, d3[3] \n" + "vmlal.s16 q10, d10, d5[3] \n" + "vmlal.s16 q14, d11, d5[3] \n" + "vmlal.s16 q11, d10, d7[3] \n" "vmlal.s16 q15, d11, d7[3] \n" "bne 0b \n" - "1: \n" - "vstm %0!, {d16-d23} \n" - "vstm %0!, {d24-d31} \n" + "vstm %1!, {d16-d23} \n" + "vstm %2!, {d24-d31} \n" - : "=r"(output0_tm), + : "=r"(nn), + "=r"(output0_tm), + "=r"(output1_tm), "=r"(r0), - "=r"(k0), - "=r"(nn) - : "0"(output0_tm), - "1"(r0), - "2"(k0), - "3"(nn) + "=r"(k0) + : "0"(nn), + "1"(output0_tm), + "2"(output1_tm), + "3"(r0), + "4"(k0) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); #endif } @@ -1391,10 +1436,11 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to } vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - vst1q_s32(output0_tm + 8, _sum2); - vst1q_s32(output0_tm + 12, _sum3); - output0_tm += 16; + vst1q_s32(output1_tm, _sum1); + vst1q_s32(output0_tm + 4, _sum2); + vst1q_s32(output1_tm + 4, _sum3); + output0_tm += 8; + output1_tm += 8; } for (; i < tiles; i++) { @@ -1451,10 +1497,626 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to k0 += 64; } + vst1q_s32(output0_tm, _sum0); + vst1q_s32(output1_tm, _sum1); + output0_tm += 4; + output1_tm += 4; + } + } + } + + remain_outch_start += nn_outch << 1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = remain_outch_start; p < outch; p++) + { + int* output0_tm = top_blob_tm.channel(p); + + const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2); + + for (int r = 0; r < 36; r++) + { + const Mat bb2 = bottom_blob_tm2.channel(r); + + int i = 0; +#if __aarch64__ + for (; i + 11 < tiles; i += 12) + { + const short* r0 = bb2.row(i / 12); + const short* k0 = kernel0_tm.row(r); + + int nn = inch; // inch always > 0 + + asm volatile( + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 + + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + + "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 + + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + + "prfm pldl1keep, [%2, #256] \n" + + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + + "prfm pldl1keep, [%3, #256] \n" + + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + + "0: \n" + + "smlal v8.4s, v4.4h, v0.h[0] \n" + "smlal v9.4s, v4.4h, v0.h[1] \n" + "smlal v10.4s, v4.4h, v0.h[2] \n" + "smlal v11.4s, v4.4h, v0.h[3] \n" + "smlal v12.4s, v4.4h, v0.h[4] \n" + "smlal v13.4s, v4.4h, v0.h[5] \n" + "smlal v14.4s, v4.4h, v0.h[6] \n" + "smlal v15.4s, v4.4h, v0.h[7] \n" + + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r23 + + "smlal v16.4s, v4.4h, v1.h[0] \n" + "smlal v17.4s, v4.4h, v1.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal v18.4s, v4.4h, v1.h[2] \n" + "smlal v19.4s, v4.4h, v1.h[3] \n" + + "smlal2 v8.4s, v4.8h, v1.h[4] \n" + "smlal2 v9.4s, v4.8h, v1.h[5] \n" + "smlal2 v10.4s, v4.8h, v1.h[6] \n" + "smlal2 v11.4s, v4.8h, v1.h[7] \n" + "smlal2 v12.4s, v4.8h, v2.h[0] \n" + "smlal2 v13.4s, v4.8h, v2.h[1] \n" + "smlal2 v14.4s, v4.8h, v2.h[2] \n" + "smlal2 v15.4s, v4.8h, v2.h[3] \n" + "smlal2 v16.4s, v4.8h, v2.h[4] \n" + "smlal2 v17.4s, v4.8h, v2.h[5] \n" + "smlal2 v18.4s, v4.8h, v2.h[6] \n" + "smlal2 v19.4s, v4.8h, v2.h[7] \n" + + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r45 + + "smlal v8.4s, v5.4h, v3.h[0] \n" + "smlal v9.4s, v5.4h, v3.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal v10.4s, v5.4h, v3.h[2] \n" + "smlal v11.4s, v5.4h, v3.h[3] \n" + "smlal v12.4s, v5.4h, v3.h[4] \n" + "smlal v13.4s, v5.4h, v3.h[5] \n" + "smlal v14.4s, v5.4h, v3.h[6] \n" + "smlal v15.4s, v5.4h, v3.h[7] \n" + "smlal v16.4s, v5.4h, v0.h[0] \n" + "smlal v17.4s, v5.4h, v0.h[1] \n" + "smlal v18.4s, v5.4h, v0.h[2] \n" + "smlal v19.4s, v5.4h, v0.h[3] \n" + + "ld1 {v6.8h, v7.8h}, [%3], #32 \n" // w23 + + "smlal2 v8.4s, v5.8h, v0.h[4] \n" + "smlal2 v9.4s, v5.8h, v0.h[5] \n" + + "prfm pldl1keep, [%3, #256] \n" + + "smlal2 v10.4s, v5.8h, v0.h[6] \n" + "smlal2 v11.4s, v5.8h, v0.h[7] \n" + + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r67 + + "smlal2 v12.4s, v5.8h, v1.h[0] \n" + "smlal2 v13.4s, v5.8h, v1.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal2 v14.4s, v5.8h, v1.h[2] \n" + "smlal2 v15.4s, v5.8h, v1.h[3] \n" + "smlal2 v16.4s, v5.8h, v1.h[4] \n" + "smlal2 v17.4s, v5.8h, v1.h[5] \n" + "smlal2 v18.4s, v5.8h, v1.h[6] \n" + "smlal2 v19.4s, v5.8h, v1.h[7] \n" + + "smlal v8.4s, v6.4h, v2.h[0] \n" + "smlal v9.4s, v6.4h, v2.h[1] \n" + "smlal v10.4s, v6.4h, v2.h[2] \n" + "smlal v11.4s, v6.4h, v2.h[3] \n" + "smlal v12.4s, v6.4h, v2.h[4] \n" + "smlal v13.4s, v6.4h, v2.h[5] \n" + "smlal v14.4s, v6.4h, v2.h[6] \n" + "smlal v15.4s, v6.4h, v2.h[7] \n" + + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r89 + + "smlal v16.4s, v6.4h, v3.h[0] \n" + "smlal v17.4s, v6.4h, v3.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal v18.4s, v6.4h, v3.h[2] \n" + "smlal v19.4s, v6.4h, v3.h[3] \n" + + "smlal2 v8.4s, v6.8h, v3.h[4] \n" + "smlal2 v9.4s, v6.8h, v3.h[5] \n" + "smlal2 v10.4s, v6.8h, v3.h[6] \n" + "smlal2 v11.4s, v6.8h, v3.h[7] \n" + "smlal2 v12.4s, v6.8h, v0.h[0] \n" + "smlal2 v13.4s, v6.8h, v0.h[1] \n" + "smlal2 v14.4s, v6.8h, v0.h[2] \n" + "smlal2 v15.4s, v6.8h, v0.h[3] \n" + "smlal2 v16.4s, v6.8h, v0.h[4] \n" + "smlal2 v17.4s, v6.8h, v0.h[5] \n" + "smlal2 v18.4s, v6.8h, v0.h[6] \n" + "smlal2 v19.4s, v6.8h, v0.h[7] \n" + + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r1011 + + "smlal v8.4s, v7.4h, v1.h[0] \n" + "smlal v9.4s, v7.4h, v1.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal v10.4s, v7.4h, v1.h[2] \n" + "smlal v11.4s, v7.4h, v1.h[3] \n" + "smlal v12.4s, v7.4h, v1.h[4] \n" + "smlal v13.4s, v7.4h, v1.h[5] \n" + "smlal v14.4s, v7.4h, v1.h[6] \n" + "smlal v15.4s, v7.4h, v1.h[7] \n" + "smlal v16.4s, v7.4h, v2.h[0] \n" + "smlal v17.4s, v7.4h, v2.h[1] \n" + "smlal v18.4s, v7.4h, v2.h[2] \n" + "smlal v19.4s, v7.4h, v2.h[3] \n" + + "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 + + "smlal2 v8.4s, v7.8h, v2.h[4] \n" + "smlal2 v9.4s, v7.8h, v2.h[5] \n" + + "prfm pldl1keep, [%3, #256] \n" + + "smlal2 v10.4s, v7.8h, v2.h[6] \n" + "smlal2 v11.4s, v7.8h, v2.h[7] \n" + + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 + + "smlal2 v12.4s, v7.8h, v3.h[0] \n" + "smlal2 v13.4s, v7.8h, v3.h[1] \n" + + "prfm pldl1keep, [%2, #256] \n" + + "smlal2 v14.4s, v7.8h, v3.h[2] \n" + "smlal2 v15.4s, v7.8h, v3.h[3] \n" + "smlal2 v16.4s, v7.8h, v3.h[4] \n" + "smlal2 v17.4s, v7.8h, v3.h[5] \n" + + "subs %w0, %w0, #1 \n" + + "smlal2 v18.4s, v7.8h, v3.h[6] \n" + "smlal2 v19.4s, v7.8h, v3.h[7] \n" + + "bne 0b \n" + + "sub %2, %2, #32 \n" + "sub %3, %3, #32 \n" + + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" + + : "=r"(nn), // %0 + "=r"(output0_tm), // %1 + "=r"(r0), // %2 + "=r"(k0) // %3 + : "0"(nn), + "1"(output0_tm), + "2"(r0), + "3"(k0) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"); + } + for (; i + 7 < tiles; i += 8) + { + const short* r0 = bb2.row(i / 12 + (i % 12) / 8); + const short* k0 = kernel0_tm.row(r); + + int nn = inch; // inch always > 0 + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int16x8_t _val0 = vld1q_s16(r0); + int16x8_t _val1 = vld1q_s16(r0 + 8); + int16x8_t _val2 = vld1q_s16(r0 + 16); + int16x8_t _val3 = vld1q_s16(r0 + 24); + int16x8_t _val4 = vld1q_s16(r0 + 32); + int16x8_t _val5 = vld1q_s16(r0 + 40); + int16x8_t _val6 = vld1q_s16(r0 + 48); + int16x8_t _val7 = vld1q_s16(r0 + 56); + + int16x8_t _w0 = vld1q_s16(k0); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w0), vget_low_s16(_val0), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val0), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w0), vget_low_s16(_val0), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_high_s16(_val0), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w0), vget_high_s16(_val0), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_high_s16(_val0), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w0), vget_high_s16(_val0), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w0), vget_low_s16(_val1), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val1), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w0), vget_low_s16(_val1), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w0), vget_high_s16(_val1), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_high_s16(_val1), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w0), vget_high_s16(_val1), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_high_s16(_val1), 3); + + int16x8_t _w1 = vld1q_s16(k0 + 8); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val2), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w1), vget_low_s16(_val2), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val2), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w1), vget_low_s16(_val2), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_high_s16(_val2), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w1), vget_high_s16(_val2), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_high_s16(_val2), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w1), vget_high_s16(_val2), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w1), vget_low_s16(_val3), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val3), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w1), vget_low_s16(_val3), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val3), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w1), vget_high_s16(_val3), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_high_s16(_val3), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w1), vget_high_s16(_val3), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_high_s16(_val3), 3); + + int16x8_t _w2 = vld1q_s16(k0 + 16); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val4), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w2), vget_low_s16(_val4), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val4), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w2), vget_low_s16(_val4), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val4), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w2), vget_high_s16(_val4), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val4), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w2), vget_high_s16(_val4), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w2), vget_low_s16(_val5), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val5), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w2), vget_low_s16(_val5), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val5), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w2), vget_high_s16(_val5), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_high_s16(_val5), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w2), vget_high_s16(_val5), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val5), 3); + + int16x8_t _w3 = vld1q_s16(k0 + 24); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val6), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w3), vget_low_s16(_val6), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val6), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w3), vget_low_s16(_val6), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val6), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w3), vget_high_s16(_val6), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val6), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w3), vget_high_s16(_val6), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w3), vget_low_s16(_val7), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val7), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w3), vget_low_s16(_val7), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val7), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w3), vget_high_s16(_val7), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val7), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w3), vget_high_s16(_val7), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val7), 3); + + r0 += 64; + k0 += 32; + } + vst1q_s32(output0_tm, _sum0); vst1q_s32(output0_tm + 4, _sum1); + vst1q_s32(output0_tm + 8, _sum2); + vst1q_s32(output0_tm + 12, _sum3); + vst1q_s32(output0_tm + 16, _sum4); + vst1q_s32(output0_tm + 20, _sum5); + vst1q_s32(output0_tm + 24, _sum6); + vst1q_s32(output0_tm + 28, _sum7); + output0_tm += 32; + } +#endif // __aarch64__ + for (; i + 3 < tiles; i += 4) + { +#if __aarch64__ + const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); +#else + const short* r0 = bb2.row(i / 4); +#endif + const short* k0 = kernel0_tm.row(r); + + int nn = inch; // inch always > 0 + +#if __aarch64__ + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int16x8_t _val0 = vld1q_s16(r0); + int16x8_t _val1 = vld1q_s16(r0 + 8); + int16x8_t _val2 = vld1q_s16(r0 + 16); + int16x8_t _val3 = vld1q_s16(r0 + 24); + + int16x8_t _w0 = vld1q_s16(k0); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val2), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val2), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val3), 0); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val3), 1); + + int16x8_t _w1 = vld1q_s16(k0 + 8); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val2), 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val2), 3); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val3), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val3), 3); + + int16x8_t _w2 = vld1q_s16(k0 + 16); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val2), 0); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_high_s16(_val2), 1); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val3), 0); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val3), 1); + + int16x8_t _w3 = vld1q_s16(k0 + 24); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val2), 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val2), 3); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val3), 2); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val3), 3); + + r0 += 32; + k0 += 32; + } + + _sum0 = vaddq_s32(_sum0, _sum1); + _sum2 = vaddq_s32(_sum2, _sum3); + _sum4 = vaddq_s32(_sum4, _sum5); + _sum6 = vaddq_s32(_sum6, _sum7); + + vst1q_s32(output0_tm, _sum0); + vst1q_s32(output0_tm + 4, _sum2); + vst1q_s32(output0_tm + 8, _sum4); + vst1q_s32(output0_tm + 12, _sum6); + output0_tm += 16; +#else + asm volatile( + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "0: \n" + + "pld [%2, #256] \n" + "pld [%2, #512] \n" + "vldm %2!, {d0-d7} \n" + + "pld [%3, #256] \n" + "vld1.s16 {d8-d11}, [%3 :128]! \n" + + "vmlal.s16 q8, d8, d0[0] \n" + "vmlal.s16 q12, d9, d0[1] \n" + "vmlal.s16 q9, d8, d2[0] \n" + "vmlal.s16 q13, d9, d2[1] \n" + "vmlal.s16 q10, d8, d4[0] \n" + "vmlal.s16 q14, d9, d4[1] \n" + "vmlal.s16 q11, d8, d6[0] \n" + "vmlal.s16 q15, d9, d6[1] \n" + + "pld [%3, #128] \n" + "vld1.s16 {d8-d9}, [%3 :128]! \n" + + "vmlal.s16 q8, d10, d0[2] \n" + "vmlal.s16 q12, d11, d0[3] \n" + "vmlal.s16 q9, d10, d2[2] \n" + "vmlal.s16 q13, d11, d2[3] \n" + "vmlal.s16 q10, d10, d4[2] \n" + "vmlal.s16 q14, d11, d4[3] \n" + "vmlal.s16 q11, d10, d6[2] \n" + "vmlal.s16 q15, d11, d6[3] \n" + + "pld [%3, #128] \n" + "vld1.s16 {d10-d11}, [%3 :128]! \n" + + "vmlal.s16 q8, d8, d1[0] \n" + "vmlal.s16 q12, d9, d1[1] \n" + "vmlal.s16 q9, d8, d3[0] \n" + "vmlal.s16 q13, d9, d3[1] \n" + "vmlal.s16 q10, d8, d5[0] \n" + "vmlal.s16 q14, d9, d5[1] \n" + "vmlal.s16 q11, d8, d7[0] \n" + "vmlal.s16 q15, d9, d7[1] \n" + + "subs %0, %0, #1 \n" + + "vmlal.s16 q8, d10, d1[2] \n" + "vmlal.s16 q12, d11, d1[3] \n" + "vmlal.s16 q9, d10, d3[2] \n" + "vmlal.s16 q13, d11, d3[3] \n" + "vmlal.s16 q10, d10, d5[2] \n" + "vmlal.s16 q14, d11, d5[3] \n" + "vmlal.s16 q11, d10, d7[2] \n" + "vmlal.s16 q15, d11, d7[3] \n" + + "bne 0b \n" + + "vadd.s32 q8, q8, q12 \n" + "vadd.s32 q9, q9, q13 \n" + "vadd.s32 q10, q10, q14 \n" + "vadd.s32 q11, q11, q15 \n" + + "vstm %1!, {d16-d23} \n" + + : "=r"(nn), + "=r"(output0_tm), + "=r"(r0), + "=r"(k0) + : "0"(nn), + "1"(output0_tm), + "2"(r0), + "3"(k0) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif + } + for (; i + 1 < tiles; i += 2) + { +#if __aarch64__ + const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); +#else + const short* r0 = bb2.row(i / 4 + (i % 4) / 2); +#endif + const short* k0 = kernel0_tm.row(r); + + int nn = inch; // inch always > 0 + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int16x8_t _val0 = vld1q_s16(r0); + int16x8_t _val1 = vld1q_s16(r0 + 8); + + int16x8_t _w0 = vld1q_s16(k0); + int16x8_t _w1 = vld1q_s16(k0 + 8); + int16x8_t _w2 = vld1q_s16(k0 + 16); + int16x8_t _w3 = vld1q_s16(k0 + 24); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); + + r0 += 16; + k0 += 32; + } + + _sum0 = vaddq_s32(_sum0, _sum1); + _sum2 = vaddq_s32(_sum2, _sum3); + + vst1q_s32(output0_tm, _sum0); + vst1q_s32(output0_tm + 4, _sum2); output0_tm += 8; } + for (; i < tiles; i++) + { +#if __aarch64__ + const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); +#else + const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); +#endif + const short* k0 = kernel0_tm.row(r); + + int nn = inch; // inch always > 0 + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int16x8_t _val0 = vld1q_s16(r0); + + int16x8_t _w0 = vld1q_s16(k0); + int16x8_t _w1 = vld1q_s16(k0 + 8); + int16x8_t _w2 = vld1q_s16(k0 + 16); + int16x8_t _w3 = vld1q_s16(k0 + 24); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); + + r0 += 8; + k0 += 32; + } + + _sum0 = vaddq_s32(_sum0, _sum1); + + vst1q_s32(output0_tm, _sum0); + output0_tm += 4; + } } } } @@ -1469,7 +2131,7 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to } else { - top_blob_bordered.create(outw, outh, outch, 4u * elempack, elempack, opt.workspace_allocator); + top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); } { // const float otm[4][6] = { @@ -1494,7 +2156,7 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to const Mat out0_tm = top_blob_tm.channel(p); Mat out0 = top_blob_bordered.channel(p); - int tmp[4][6][8]; + int tmp[4][6][4]; // tile for (int i = 0; i < outh / 4; i++) @@ -1503,193 +2165,131 @@ static void conv3x3s1_winograd42_pack8_int8_neon(const Mat& bottom_blob, Mat& to { // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 8; - const int* output0_tm_1 = output0_tm_0 + tiles * 8; - const int* output0_tm_2 = output0_tm_0 + tiles * 16; - const int* output0_tm_3 = output0_tm_0 + tiles * 24; - const int* output0_tm_4 = output0_tm_0 + tiles * 32; - const int* output0_tm_5 = output0_tm_0 + tiles * 40; + const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 4; + const int* output0_tm_1 = output0_tm_0 + tiles * 4; + const int* output0_tm_2 = output0_tm_0 + tiles * 8; + const int* output0_tm_3 = output0_tm_0 + tiles * 12; + const int* output0_tm_4 = output0_tm_0 + tiles * 16; + const int* output0_tm_5 = output0_tm_0 + tiles * 20; - int* output0 = out0.row(i * 4) + (j * 4) * 8; + int* output0 = out0.row(i * 4) + (j * 4) * 4; // TODO neon optimize for (int m = 0; m < 5; m++) { - int32x4_t _out0tm0_low = vld1q_s32(output0_tm_0); - int32x4_t _out0tm0_high = vld1q_s32(output0_tm_0 + 4); - int32x4_t _out0tm1_low = vld1q_s32(output0_tm_1); - int32x4_t _out0tm1_high = vld1q_s32(output0_tm_1 + 4); - int32x4_t _out0tm2_low = vld1q_s32(output0_tm_2); - int32x4_t _out0tm2_high = vld1q_s32(output0_tm_2 + 4); - int32x4_t _out0tm3_low = vld1q_s32(output0_tm_3); - int32x4_t _out0tm3_high = vld1q_s32(output0_tm_3 + 4); - int32x4_t _out0tm4_low = vld1q_s32(output0_tm_4); - int32x4_t _out0tm4_high = vld1q_s32(output0_tm_4 + 4); - int32x4_t _out0tm5_low = vld1q_s32(output0_tm_5); - int32x4_t _out0tm5_high = vld1q_s32(output0_tm_5 + 4); - - int32x4_t _tmp02a_low = vaddq_s32(_out0tm1_low, _out0tm2_low); - int32x4_t _tmp02a_high = vaddq_s32(_out0tm1_high, _out0tm2_high); - int32x4_t _tmp13a_low = vsubq_s32(_out0tm1_low, _out0tm2_low); - int32x4_t _tmp13a_high = vsubq_s32(_out0tm1_high, _out0tm2_high); - - int32x4_t _tmp02b_low = vaddq_s32(_out0tm3_low, _out0tm4_low); - int32x4_t _tmp02b_high = vaddq_s32(_out0tm3_high, _out0tm4_high); - int32x4_t _tmp13b_low = vsubq_s32(_out0tm3_low, _out0tm4_low); - int32x4_t _tmp13b_high = vsubq_s32(_out0tm3_high, _out0tm4_high); + int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); + int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); + int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); + int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); + int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); + int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); + + int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); + int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); + + int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); + int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); int32x4_t _v2 = vdupq_n_s32(2); int32x4_t _v4 = vdupq_n_s32(4); int32x4_t _v8 = vdupq_n_s32(8); - int32x4_t _tmp0m_low = vaddq_s32(vaddq_s32(_out0tm0_low, _tmp02a_low), _tmp02b_low); - int32x4_t _tmp0m_high = vaddq_s32(vaddq_s32(_out0tm0_high, _tmp02a_high), _tmp02b_high); - int32x4_t _tmp1m_low = vmlaq_s32(_tmp13a_low, _tmp13b_low, _v2); - int32x4_t _tmp1m_high = vmlaq_s32(_tmp13a_high, _tmp13b_high, _v2); - int32x4_t _tmp2m_low = vmlaq_s32(_tmp02a_low, _tmp02b_low, _v4); - int32x4_t _tmp2m_high = vmlaq_s32(_tmp02a_high, _tmp02b_high, _v4); - int32x4_t _tmp3m_low = vmlaq_s32(vmlaq_s32(_tmp13a_low, _out0tm5_low, _v4), _tmp13b_low, _v8); - int32x4_t _tmp3m_high = vmlaq_s32(vmlaq_s32(_tmp13a_high, _out0tm5_high, _v4), _tmp13b_high, _v8); - - vst1q_s32(tmp[0][m], _tmp0m_low); - vst1q_s32(tmp[0][m] + 4, _tmp0m_high); - vst1q_s32(tmp[1][m], _tmp1m_low); - vst1q_s32(tmp[1][m] + 4, _tmp1m_high); - vst1q_s32(tmp[2][m], _tmp2m_low); - vst1q_s32(tmp[2][m] + 4, _tmp2m_high); - vst1q_s32(tmp[3][m], _tmp3m_low); - vst1q_s32(tmp[3][m] + 4, _tmp3m_high); - - output0_tm_0 += tiles * 48; - output0_tm_1 += tiles * 48; - output0_tm_2 += tiles * 48; - output0_tm_3 += tiles * 48; - output0_tm_4 += tiles * 48; - output0_tm_5 += tiles * 48; + int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); + int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); + int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); + int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); + + vst1q_s32(tmp[0][m], _tmp0m); + vst1q_s32(tmp[1][m], _tmp1m); + vst1q_s32(tmp[2][m], _tmp2m); + vst1q_s32(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 24; + output0_tm_1 += tiles * 24; + output0_tm_2 += tiles * 24; + output0_tm_3 += tiles * 24; + output0_tm_4 += tiles * 24; + output0_tm_5 += tiles * 24; } for (int m = 5; m < 6; m++) { - int32x4_t _out0tm0_low = vld1q_s32(output0_tm_0); - int32x4_t _out0tm0_high = vld1q_s32(output0_tm_0 + 4); - int32x4_t _out0tm1_low = vld1q_s32(output0_tm_1); - int32x4_t _out0tm1_high = vld1q_s32(output0_tm_1 + 4); - int32x4_t _out0tm2_low = vld1q_s32(output0_tm_2); - int32x4_t _out0tm2_high = vld1q_s32(output0_tm_2 + 4); - int32x4_t _out0tm3_low = vld1q_s32(output0_tm_3); - int32x4_t _out0tm3_high = vld1q_s32(output0_tm_3 + 4); - int32x4_t _out0tm4_low = vld1q_s32(output0_tm_4); - int32x4_t _out0tm4_high = vld1q_s32(output0_tm_4 + 4); - int32x4_t _out0tm5_low = vld1q_s32(output0_tm_5); - int32x4_t _out0tm5_high = vld1q_s32(output0_tm_5 + 4); - - int32x4_t _tmp02a_low = vaddq_s32(_out0tm1_low, _out0tm2_low); - int32x4_t _tmp02a_high = vaddq_s32(_out0tm1_high, _out0tm2_high); - int32x4_t _tmp13a_low = vsubq_s32(_out0tm1_low, _out0tm2_low); - int32x4_t _tmp13a_high = vsubq_s32(_out0tm1_high, _out0tm2_high); - - int32x4_t _tmp02b_low = vaddq_s32(_out0tm3_low, _out0tm4_low); - int32x4_t _tmp02b_high = vaddq_s32(_out0tm3_high, _out0tm4_high); - int32x4_t _tmp13b_low = vsubq_s32(_out0tm3_low, _out0tm4_low); - int32x4_t _tmp13b_high = vsubq_s32(_out0tm3_high, _out0tm4_high); + int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); + int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); + int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); + int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); + int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); + int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); + + int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); + int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); + + int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); + int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); int32x4_t _v2 = vdupq_n_s32(2); int32x4_t _v4 = vdupq_n_s32(4); int32x4_t _v8 = vdupq_n_s32(8); - int32x4_t _tmp0m_low = vaddq_s32(vaddq_s32(_out0tm0_low, _tmp02a_low), _tmp02b_low); - int32x4_t _tmp0m_high = vaddq_s32(vaddq_s32(_out0tm0_high, _tmp02a_high), _tmp02b_high); - int32x4_t _tmp1m_low = vmlaq_s32(_tmp13a_low, _tmp13b_low, _v2); - int32x4_t _tmp1m_high = vmlaq_s32(_tmp13a_high, _tmp13b_high, _v2); - int32x4_t _tmp2m_low = vmlaq_s32(_tmp02a_low, _tmp02b_low, _v4); - int32x4_t _tmp2m_high = vmlaq_s32(_tmp02a_high, _tmp02b_high, _v4); - int32x4_t _tmp3m_low = vmlaq_s32(vmlaq_s32(_tmp13a_low, _out0tm5_low, _v4), _tmp13b_low, _v8); - int32x4_t _tmp3m_high = vmlaq_s32(vmlaq_s32(_tmp13a_high, _out0tm5_high, _v4), _tmp13b_high, _v8); - - _tmp0m_low = vmulq_s32(_tmp0m_low, _v4); - _tmp0m_high = vmulq_s32(_tmp0m_high, _v4); - _tmp1m_low = vmulq_s32(_tmp1m_low, _v4); - _tmp1m_high = vmulq_s32(_tmp1m_high, _v4); - _tmp2m_low = vmulq_s32(_tmp2m_low, _v4); - _tmp2m_high = vmulq_s32(_tmp2m_high, _v4); - _tmp3m_low = vmulq_s32(_tmp3m_low, _v4); - _tmp3m_high = vmulq_s32(_tmp3m_high, _v4); - - vst1q_s32(tmp[0][m], _tmp0m_low); - vst1q_s32(tmp[0][m] + 4, _tmp0m_high); - vst1q_s32(tmp[1][m], _tmp1m_low); - vst1q_s32(tmp[1][m] + 4, _tmp1m_high); - vst1q_s32(tmp[2][m], _tmp2m_low); - vst1q_s32(tmp[2][m] + 4, _tmp2m_high); - vst1q_s32(tmp[3][m], _tmp3m_low); - vst1q_s32(tmp[3][m] + 4, _tmp3m_high); - - output0_tm_0 += tiles * 48; - output0_tm_1 += tiles * 48; - output0_tm_2 += tiles * 48; - output0_tm_3 += tiles * 48; - output0_tm_4 += tiles * 48; - output0_tm_5 += tiles * 48; + int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); + int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); + int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); + int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); + + _tmp0m = vmulq_s32(_tmp0m, _v4); + _tmp1m = vmulq_s32(_tmp1m, _v4); + _tmp2m = vmulq_s32(_tmp2m, _v4); + _tmp3m = vmulq_s32(_tmp3m, _v4); + + vst1q_s32(tmp[0][m], _tmp0m); + vst1q_s32(tmp[1][m], _tmp1m); + vst1q_s32(tmp[2][m], _tmp2m); + vst1q_s32(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 24; + output0_tm_1 += tiles * 24; + output0_tm_2 += tiles * 24; + output0_tm_3 += tiles * 24; + output0_tm_4 += tiles * 24; + output0_tm_5 += tiles * 24; } for (int m = 0; m < 4; m++) { - int32x4_t _tmp00_low = vld1q_s32(tmp[m][0]); - int32x4_t _tmp00_high = vld1q_s32(tmp[m][0] + 4); - int32x4_t _tmp01_low = vld1q_s32(tmp[m][1]); - int32x4_t _tmp01_high = vld1q_s32(tmp[m][1] + 4); - int32x4_t _tmp02_low = vld1q_s32(tmp[m][2]); - int32x4_t _tmp02_high = vld1q_s32(tmp[m][2] + 4); - int32x4_t _tmp03_low = vld1q_s32(tmp[m][3]); - int32x4_t _tmp03_high = vld1q_s32(tmp[m][3] + 4); - int32x4_t _tmp04_low = vld1q_s32(tmp[m][4]); - int32x4_t _tmp04_high = vld1q_s32(tmp[m][4] + 4); - int32x4_t _tmp05_low = vld1q_s32(tmp[m][5]); - int32x4_t _tmp05_high = vld1q_s32(tmp[m][5] + 4); - - int32x4_t _tmp02a_low = vaddq_s32(_tmp01_low, _tmp02_low); - int32x4_t _tmp02a_high = vaddq_s32(_tmp01_high, _tmp02_high); - int32x4_t _tmp13a_low = vsubq_s32(_tmp01_low, _tmp02_low); - int32x4_t _tmp13a_high = vsubq_s32(_tmp01_high, _tmp02_high); - - int32x4_t _tmp02b_low = vaddq_s32(_tmp03_low, _tmp04_low); - int32x4_t _tmp02b_high = vaddq_s32(_tmp03_high, _tmp04_high); - int32x4_t _tmp13b_low = vsubq_s32(_tmp03_low, _tmp04_low); - int32x4_t _tmp13b_high = vsubq_s32(_tmp03_high, _tmp04_high); + int32x4_t _tmp00 = vld1q_s32(tmp[m][0]); + int32x4_t _tmp01 = vld1q_s32(tmp[m][1]); + int32x4_t _tmp02 = vld1q_s32(tmp[m][2]); + int32x4_t _tmp03 = vld1q_s32(tmp[m][3]); + int32x4_t _tmp04 = vld1q_s32(tmp[m][4]); + int32x4_t _tmp05 = vld1q_s32(tmp[m][5]); + + int32x4_t _tmp02a = vaddq_s32(_tmp01, _tmp02); + int32x4_t _tmp13a = vsubq_s32(_tmp01, _tmp02); + + int32x4_t _tmp02b = vaddq_s32(_tmp03, _tmp04); + int32x4_t _tmp13b = vsubq_s32(_tmp03, _tmp04); int32x4_t _v2 = vdupq_n_s32(2); int32x4_t _v4 = vdupq_n_s32(4); int32x4_t _v8 = vdupq_n_s32(8); - int32x4_t _out00_low = vaddq_s32(vaddq_s32(_tmp00_low, _tmp02a_low), _tmp02b_low); - int32x4_t _out00_high = vaddq_s32(vaddq_s32(_tmp00_high, _tmp02a_high), _tmp02b_high); - int32x4_t _out01_low = vmlaq_s32(_tmp13a_low, _tmp13b_low, _v2); - int32x4_t _out01_high = vmlaq_s32(_tmp13a_high, _tmp13b_high, _v2); - int32x4_t _out02_low = vmlaq_s32(_tmp02a_low, _tmp02b_low, _v4); - int32x4_t _out02_high = vmlaq_s32(_tmp02a_high, _tmp02b_high, _v4); - int32x4_t _out03_low = vmlaq_s32(vaddq_s32(_tmp05_low, _tmp13a_low), _tmp13b_low, _v8); - int32x4_t _out03_high = vmlaq_s32(vaddq_s32(_tmp05_high, _tmp13a_high), _tmp13b_high, _v8); + int32x4_t _out00 = vaddq_s32(vaddq_s32(_tmp00, _tmp02a), _tmp02b); + int32x4_t _out01 = vmlaq_s32(_tmp13a, _tmp13b, _v2); + int32x4_t _out02 = vmlaq_s32(_tmp02a, _tmp02b, _v4); + int32x4_t _out03 = vmlaq_s32(vaddq_s32(_tmp05, _tmp13a), _tmp13b, _v8); // TODO use integer trick for division by 576 float32x4_t _v576 = vdupq_n_f32(1.0 / 576); - _out00_low = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out00_low), _v576)); - _out00_high = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out00_high), _v576)); - _out01_low = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out01_low), _v576)); - _out01_high = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out01_high), _v576)); - _out02_low = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out02_low), _v576)); - _out02_high = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out02_high), _v576)); - _out03_low = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out03_low), _v576)); - _out03_high = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out03_high), _v576)); - - vst1q_s32(output0, _out00_low); - vst1q_s32(output0 + 4, _out00_high); - vst1q_s32(output0 + 8, _out01_low); - vst1q_s32(output0 + 12, _out01_high); - vst1q_s32(output0 + 16, _out02_low); - vst1q_s32(output0 + 20, _out02_high); - vst1q_s32(output0 + 24, _out03_low); - vst1q_s32(output0 + 28, _out03_high); - - output0 += outw * 8; + _out00 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out00), _v576)); + _out01 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out01), _v576)); + _out02 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out02), _v576)); + _out03 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out03), _v576)); + + vst1q_s32(output0, _out00); + vst1q_s32(output0 + 4, _out01); + vst1q_s32(output0 + 8, _out02); + vst1q_s32(output0 + 12, _out03); + + output0 += outw * 4; } } } diff --git a/src/layer/arm/convolution_7x7_pack1to8_int8.h b/src/layer/arm/convolution_7x7_pack1to4_int8.h similarity index 95% rename from src/layer/arm/convolution_7x7_pack1to8_int8.h rename to src/layer/arm/convolution_7x7_pack1to4_int8.h index 38b4d50f2..47a6d4c9f 100644 --- a/src/layer/arm/convolution_7x7_pack1to8_int8.h +++ b/src/layer/arm/convolution_7x7_pack1to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv7x7s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv7x7s2_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int inch = bottom_blob.c; @@ -76,5 +76,5 @@ static void conv7x7s2_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, } } - im2col_sgemm_pack1to8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolution_arm.cpp b/src/layer/arm/convolution_arm.cpp index 95ca29533..e8435f10b 100644 --- a/src/layer/arm/convolution_arm.cpp +++ b/src/layer/arm/convolution_arm.cpp @@ -70,18 +70,18 @@ namespace ncnn { #include "convolution_7x7_pack1to4_bf16s.h" #if NCNN_INT8 -#include "convolution_pack8_int8.h" -#include "convolution_pack1to8_int8.h" +#include "convolution_pack8to4_int8.h" +#include "convolution_pack1to4_int8.h" #include "convolution_pack8to1_int8.h" -#include "convolution_sgemm_pack8_int8.h" -#include "convolution_sgemm_pack1to8_int8.h" +#include "convolution_sgemm_pack8to4_int8.h" +#include "convolution_sgemm_pack1to4_int8.h" #include "convolution_sgemm_pack8to1_int8.h" -#include "convolution_1x1_pack8_int8.h" -#include "convolution_1x1_pack1to8_int8.h" +#include "convolution_1x1_pack8to4_int8.h" +#include "convolution_1x1_pack1to4_int8.h" #include "convolution_1x1_pack8to1_int8.h" -#include "convolution_3x3_pack8_int8.h" -#include "convolution_3x3_pack1to8_int8.h" -#include "convolution_7x7_pack1to8_int8.h" +#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_pack1to4_int8.h" +#include "convolution_7x7_pack1to4_int8.h" #include "convolution_3x3_pack8to1_int8.h" #endif // NCNN_INT8 @@ -1787,7 +1787,7 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) if (opt.use_packing_layout) { elempack = num_input % 8 == 0 ? 8 : 1; - out_elempack = num_output % 8 == 0 ? 8 : 1; + out_elempack = num_output % 4 == 0 ? 4 : 1; } #endif // __ARM_NEON @@ -1855,15 +1855,15 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) } #if __ARM_NEON - if (elempack == 8 && out_elempack == 8) + if (elempack == 8 && out_elempack == 4) { if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - convolution_im2col_sgemm_transform_kernel_pack8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - convolution_im2col_sgemm_transform_kernel_pack8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } #if __ARM_FEATURE_DOTPROD else if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 256 && num_output >= 256) @@ -1871,39 +1871,39 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) else if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) #endif { - conv3x3s1_winograd42_transform_kernel_pack8_int8_neon(weight_data, weight_data_int8, num_input, num_output); + conv3x3s1_winograd42_transform_kernel_pack8to4_int8_neon(weight_data, weight_data_int8, num_input, num_output); } else if (opt.use_sgemm_convolution) { - convolution_im2col_sgemm_transform_kernel_pack8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } } - if (elempack == 1 && out_elempack == 8) + if (elempack == 1 && out_elempack == 4) { if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) { - convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); } } @@ -1966,7 +1966,7 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con #if __ARM_NEON if (opt.use_packing_layout) { - out_elempack = num_output % 8 == 0 ? 8 : 1; + out_elempack = num_output % 4 == 0 ? 4 : 1; } #endif // __ARM_NEON bool use_int8_requantize = int8_scale_term > 100; @@ -1988,26 +1988,21 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con const int num_input = channels * elempack; + Mat top_blob_int32; + top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); + if (top_blob_int32.empty()) + return -100; + #if __ARM_NEON - if (elempack == 8 && out_elempack == 8) + if (elempack == 8 && out_elempack == 4) { - Mat top_blob_int32; - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv1x1s1_sgemm_pack8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s1_sgemm_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv1x1s2_pack8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s2_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } #if __ARM_FEATURE_DOTPROD else if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 256 && num_output >= 256) @@ -2015,27 +2010,15 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con else if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) #endif { - top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv3x3s1_winograd42_pack8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv3x3s1_winograd42_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (opt.use_sgemm_convolution) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - convolution_im2col_sgemm_pack8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_im2col_sgemm_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } else { - top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - convolution_pack8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } Mat scale_in_data(num_output); @@ -2066,65 +2049,35 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con } } - if (elempack == 1 && out_elempack == 8) + if (elempack == 1 && out_elempack == 4) { - Mat top_blob_int32; - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv1x1s1_sgemm_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s1_sgemm_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv1x1s2_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s2_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv3x3s1_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv3x3s1_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv3x3s2_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv3x3s2_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - conv7x7s2_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv7x7s2_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); } else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) { - top_blob_int32.create(outw, outh, num_output / 4, (size_t)(4u * 4), 4, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - convolution_im2col_sgemm_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_im2col_sgemm_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } else { - top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - - convolution_pack1to8_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_pack1to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } Mat scale_in_data(num_output); @@ -2157,11 +2110,6 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con if (elempack == 8 && out_elempack == 1) { - Mat top_blob_int32; - top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { conv1x1s1_sgemm_pack8to1_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); @@ -2214,11 +2162,6 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con if (elempack == 1 && out_elempack == 1) { - Mat top_blob_int32; - top_blob_int32.create(outw, outh, num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator); - if (top_blob_int32.empty()) - return -100; - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { conv1x1s1_sgemm_int8_neon(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); diff --git a/src/layer/arm/convolution_pack1to8_int8.h b/src/layer/arm/convolution_pack1to4_int8.h similarity index 89% rename from src/layer/arm/convolution_pack1to8_int8.h rename to src/layer/arm/convolution_pack1to4_int8.h index 200bcc53c..9f6390fb2 100644 --- a/src/layer/arm/convolution_pack1to8_int8.h +++ b/src/layer/arm/convolution_pack1to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void convolution_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +static void convolution_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { int w = bottom_blob.w; int channels = bottom_blob.c; @@ -53,7 +53,6 @@ static void convolution_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob for (int j = 0; j < outw; j++) { int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); const signed char* kptr = weight_data_int8.channel(p); @@ -69,17 +68,15 @@ static void convolution_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob int8x8_t _w = vld1_s8(kptr); int16x8_t _s0 = vmull_s8(_val, _w); _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); - _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); kptr += 8; } } - vst1q_s32(outptr + j * 8, _sum0); - vst1q_s32(outptr + j * 8 + 4, _sum1); + vst1q_s32(outptr + j * 4, _sum0); } - outptr += outw * 8; + outptr += outw * 4; } } } diff --git a/src/layer/arm/convolution_pack8_int8.h b/src/layer/arm/convolution_pack8to4_int8.h similarity index 67% rename from src/layer/arm/convolution_pack8_int8.h rename to src/layer/arm/convolution_pack8to4_int8.h index b9eebe8f5..d8503a938 100644 --- a/src/layer/arm/convolution_pack8_int8.h +++ b/src/layer/arm/convolution_pack8to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void convolution_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +static void convolution_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { int w = bottom_blob.w; int channels = bottom_blob.c; @@ -54,8 +54,6 @@ static void convolution_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, c { int32x4_t _sum01 = vdupq_n_s32(0); int32x4_t _sum23 = vdupq_n_s32(0); - int32x4_t _sum45 = vdupq_n_s32(0); - int32x4_t _sum67 = vdupq_n_s32(0); const signed char* kptr = weight_data_int8.channel(p); @@ -73,46 +71,30 @@ static void convolution_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, c int8x8_t _w1 = vld1_s8(kptr + 8); int8x8_t _w2 = vld1_s8(kptr + 16); int8x8_t _w3 = vld1_s8(kptr + 24); - int8x8_t _w4 = vld1_s8(kptr + 32); - int8x8_t _w5 = vld1_s8(kptr + 40); - int8x8_t _w6 = vld1_s8(kptr + 48); - int8x8_t _w7 = vld1_s8(kptr + 56); int16x8_t _wv0 = vmull_s8(_val, _w0); int16x8_t _wv1 = vmull_s8(_val, _w1); int16x8_t _wv2 = vmull_s8(_val, _w2); int16x8_t _wv3 = vmull_s8(_val, _w3); - int16x8_t _wv4 = vmull_s8(_val, _w4); - int16x8_t _wv5 = vmull_s8(_val, _w5); - int16x8_t _wv6 = vmull_s8(_val, _w6); - int16x8_t _wv7 = vmull_s8(_val, _w7); int16x4_t _wv00 = vpadd_s16(vget_low_s16(_wv0), vget_high_s16(_wv0)); int16x4_t _wv11 = vpadd_s16(vget_low_s16(_wv1), vget_high_s16(_wv1)); int16x4_t _wv22 = vpadd_s16(vget_low_s16(_wv2), vget_high_s16(_wv2)); int16x4_t _wv33 = vpadd_s16(vget_low_s16(_wv3), vget_high_s16(_wv3)); - int16x4_t _wv44 = vpadd_s16(vget_low_s16(_wv4), vget_high_s16(_wv4)); - int16x4_t _wv55 = vpadd_s16(vget_low_s16(_wv5), vget_high_s16(_wv5)); - int16x4_t _wv66 = vpadd_s16(vget_low_s16(_wv6), vget_high_s16(_wv6)); - int16x4_t _wv77 = vpadd_s16(vget_low_s16(_wv7), vget_high_s16(_wv7)); _sum01 = vpadalq_s16(_sum01, vcombine_s16(_wv00, _wv11)); _sum23 = vpadalq_s16(_sum23, vcombine_s16(_wv22, _wv33)); - _sum45 = vpadalq_s16(_sum45, vcombine_s16(_wv44, _wv55)); - _sum67 = vpadalq_s16(_sum67, vcombine_s16(_wv66, _wv77)); - kptr += 64; + kptr += 32; } } int32x4_t _sum0 = vcombine_s32(vpadd_s32(vget_low_s32(_sum01), vget_high_s32(_sum01)), vpadd_s32(vget_low_s32(_sum23), vget_high_s32(_sum23))); - int32x4_t _sum1 = vcombine_s32(vpadd_s32(vget_low_s32(_sum45), vget_high_s32(_sum45)), vpadd_s32(vget_low_s32(_sum67), vget_high_s32(_sum67))); - vst1q_s32(outptr + j * 8, _sum0); - vst1q_s32(outptr + j * 8 + 4, _sum1); + vst1q_s32(outptr + j * 4, _sum0); } - outptr += outw * 8; + outptr += outw * 4; } } } diff --git a/src/layer/arm/convolution_sgemm_pack1to8_int8.h b/src/layer/arm/convolution_sgemm_pack1to4_int8.h similarity index 99% rename from src/layer/arm/convolution_sgemm_pack1to8_int8.h rename to src/layer/arm/convolution_sgemm_pack1to4_int8.h index aec9ec94c..0d459f18c 100644 --- a/src/layer/arm/convolution_sgemm_pack1to8_int8.h +++ b/src/layer/arm/convolution_sgemm_pack1to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void im2col_sgemm_pack1to8_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -2422,7 +2422,7 @@ static void im2col_sgemm_pack1to8_int8_neon(const Mat& bottom_im2col, Mat& top_b } } -static void convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { const int maxk = kernel_w * kernel_h; @@ -2519,7 +2519,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to8_int8_neon(const M } } -static void convolution_im2col_sgemm_pack1to8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +static void convolution_im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { int w = bottom_blob.w; int inch = bottom_blob.c; @@ -2583,5 +2583,5 @@ static void convolution_im2col_sgemm_pack1to8_int8_neon(const Mat& bottom_blob, } } - im2col_sgemm_pack1to8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolution_sgemm_pack8_int8.h b/src/layer/arm/convolution_sgemm_pack8to4_int8.h similarity index 98% rename from src/layer/arm/convolution_sgemm_pack8_int8.h rename to src/layer/arm/convolution_sgemm_pack8to4_int8.h index 73618fd55..6efa56fa3 100644 --- a/src/layer/arm/convolution_sgemm_pack8_int8.h +++ b/src/layer/arm/convolution_sgemm_pack8to4_int8.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void im2col_sgemm_pack8_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -1103,7 +1103,7 @@ static void im2col_sgemm_pack8_int8_neon(const Mat& bottom_im2col, Mat& top_blob } } -static void convolution_im2col_sgemm_transform_kernel_pack8_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { const int maxk = kernel_w * kernel_h; @@ -1163,7 +1163,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack8_int8_neon(const Mat& } } -static void convolution_im2col_sgemm_pack8_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +static void convolution_im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { int w = bottom_blob.w; int inch = bottom_blob.c; @@ -1234,5 +1234,5 @@ static void convolution_im2col_sgemm_pack8_int8_neon(const Mat& bottom_blob, Mat } } - im2col_sgemm_pack8_int8_neon(bottom_im2col, top_blob, kernel, opt); + im2col_sgemm_pack8to4_int8_neon(bottom_im2col, top_blob, kernel, opt); } diff --git a/src/layer/arm/convolutiondepthwise_arm.cpp b/src/layer/arm/convolutiondepthwise_arm.cpp index 4f13f1ccb..ebc9121e0 100644 --- a/src/layer/arm/convolutiondepthwise_arm.cpp +++ b/src/layer/arm/convolutiondepthwise_arm.cpp @@ -1538,31 +1538,31 @@ int ConvolutionDepthWise_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_ int outw = (w - kernel_extent_w) / stride_w + 1; int outh = (h - kernel_extent_h) / stride_h + 1; - int out_elempack = 1; -#if __ARM_NEON - if (opt.use_packing_layout) + // depth-wise + if (channels * elempack == group && group == num_output) { - out_elempack = num_output % 8 == 0 ? 8 : 1; - } + int out_elempack = 1; +#if __ARM_NEON + if (opt.use_packing_layout) + { + out_elempack = num_output % 8 == 0 ? 8 : 1; + } #endif // __ARM_NEON - bool use_int8_requantize = int8_scale_term > 100; - size_t out_elemsize = use_int8_requantize ? 1u * out_elempack : 4u * out_elempack; + bool use_int8_requantize = int8_scale_term > 100; + size_t out_elemsize = use_int8_requantize ? 1u * out_elempack : 4u * out_elempack; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - if (opt.use_fp16_storage) - { - out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; - } + if (opt.use_fp16_storage) + { + out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; + } #endif - if (opt.use_bf16_storage) - out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; + if (opt.use_bf16_storage) + out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; - top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); - if (top_blob.empty()) - return -100; + top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; - // depth-wise - if (channels * elempack == group && group == num_output) - { // TODO use fp16 / bf16 out_elemsize = use_int8_requantize ? 1u * out_elempack : 4u * out_elempack; top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); @@ -1934,6 +1934,28 @@ int ConvolutionDepthWise_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_ return 0; } + int out_elempack = 1; +#if __ARM_NEON + if (opt.use_packing_layout) + { + out_elempack = num_output % 4 == 0 ? 4 : 1; + } +#endif // __ARM_NEON + bool use_int8_requantize = int8_scale_term > 100; + size_t out_elemsize = use_int8_requantize ? 1u * out_elempack : 4u * out_elempack; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (opt.use_fp16_storage) + { + out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; + } +#endif + if (opt.use_bf16_storage) + out_elemsize = use_int8_requantize ? 1u * out_elempack : 2u * out_elempack; + + top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + // group convolution const int channels_g = channels * elempack / group; const int num_output_g = num_output / group; @@ -1944,7 +1966,7 @@ int ConvolutionDepthWise_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_ if (opt.use_packing_layout) { g_elempack = channels_g % 8 == 0 ? 8 : 1; - out_g_elempack = num_output_g % 8 == 0 ? 8 : 1; + out_g_elempack = num_output_g % 4 == 0 ? 4 : 1; } #endif // __ARM_NEON diff --git a/src/layer/arm/padding_pack8_int8.h b/src/layer/arm/padding_pack8_int8.h index 9bfc5d202..7424398de 100644 --- a/src/layer/arm/padding_pack8_int8.h +++ b/src/layer/arm/padding_pack8_int8.h @@ -12,50 +12,328 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void padding_constant_pack8_int8_neon(const Mat& src, Mat& dst, int top, int bottom, int left, int right, int8x8_t _v) +static void padding_constant_pack8_int8_neon(const Mat& src, Mat& dst, int top, int bottom, int left, int right, int8x8_t v) { const signed char* ptr = src; signed char* outptr = dst; - // fill top - for (int y = 0; y < top; y++) - { - for (int x = 0; x < dst.w; x++) - { - vst1_s8(outptr, _v); - outptr += 8; - } - } - // fill center - for (int y = 0; y < src.h; y++) - { - for (int x = 0; x < left; x++) - { - vst1_s8(outptr, _v); - outptr += 8; - } - for (int x = 0; x < src.w; x++) - { - int8x8_t _p = vld1_s8(ptr); - vst1_s8(outptr, _p); - ptr += 8; - outptr += 8; - } - for (int x = 0; x < right; x++) - { - vst1_s8(outptr, _v); - outptr += 8; - } - } - // fill bottom - for (int y = 0; y < bottom; y++) - { - for (int x = 0; x < dst.w; x++) - { - vst1_s8(outptr, _v); - outptr += 8; - } - } + int w = src.w; + int h = src.h; + + int top_size = top * dst.w; + int bottom_size = bottom * dst.w; + +#if __aarch64__ + asm volatile( + "mov v0.8b, %10.8b \n" + "mov v0.d[1], v0.d[0] \n" + "mov v1.16b, v0.16b \n" + "mov v2.16b, v0.16b \n" + "mov v3.16b, v0.16b \n" + + // fill top + "lsr w4, %w8, #3 \n" // w4 = nn = top_size >> 3 + "cmp w4, #0 \n" + "beq 1f \n" + + "0: \n" + "st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%0], #64 \n" + "subs w4, w4, #1 \n" + "bne 0b \n" + + "1: \n" + + // fill top remain + "and w4, %w8, #7 \n" // w4 = remain = top_size & 7 + + "cmp w4, #4 \n" // w4 >= 4 + "blt 2f \n" + "sub w4, w4, #4 \n" + "st1 {v0.16b, v1.16b}, [%0], #32 \n" + "2: \n" + + "cmp w4, #2 \n" // w4 >= 2 + "blt 3f \n" + "sub w4, w4, #2 \n" + "st1 {v0.16b}, [%0], #16 \n" + "3: \n" + + "cmp w4, #0 \n" // w4 > 0 + "beq 4f \n" + "st1 {v0.8b}, [%0], #8 \n" + "4: \n" + + // fill center h loop + "cmp %w5, #0 \n" + "beq 15f \n" + "5: \n" + + // fill left + "mov w4, %w6 \n" // w4 = left + "cmp w4, #0 \n" + "beq 7f \n" + + "6: \n" + "st1 {v0.8b}, [%0], #8 \n" + "subs w4, w4, #1 \n" + "bne 6b \n" + + "7: \n" + + // fill middle + "lsr w4, %w4, #3 \n" // w4 = nn = w >> 3 + "cmp w4, #0 \n" + "beq 9f \n" + + "8: \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%1], #64 \n" + "subs w4, w4, #1 \n" + "st1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%0], #64 \n" + "bne 8b \n" + + "9: \n" + + "and w4, %w4, #7 \n" // w4 = remain = w & 7 + + "cmp w4, #4 \n" // w4 >= 4 + "blt 10f \n" + "prfm pldl1keep, [%1, #256] \n" + "ld1 {v16.16b, v17.16b}, [%1], #32 \n" + "sub w4, w4, #4 \n" + "st1 {v16.16b, v17.16b}, [%0], #32 \n" + "10: \n" + + "cmp w4, #2 \n" // w4 >= 2 + "blt 11f \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v16.16b}, [%1], #16 \n" + "sub w4, w4, #2 \n" + "st1 {v16.16b}, [%0], #16 \n" + "11: \n" + + "cmp w4, #0 \n" // w4 > 0 + "beq 12f \n" + "prfm pldl1keep, [%1, #64] \n" + "ld1 {v16.8b}, [%1], #8 \n" + "st1 {v16.8b}, [%0], #8 \n" + "12: \n" + + // fill right + "mov w4, %w7 \n" // w4 = right + "cmp w4, #0 \n" + "beq 14f \n" + + "13: \n" + "subs w4, w4, #1 \n" + "st1 {v0.8b}, [%0], #8 \n" + "bne 13b \n" + "14: \n" + + "subs %w5, %w5, #1 \n" + "bne 5b \n" + + "15: \n" + + // fill bottom + "lsr w4, %w9, #3 \n" // w4 = nn = bottom_size >> 3 + "cmp w4, #0 \n" + "beq 17f \n" + + "16: \n" + "st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%0], #64 \n" + "subs w4, w4, #1 \n" + "bne 16b \n" + "17: \n" + + // fill bottom remain + "and w4, %w9, #7 \n" // w4 = remain = bottom_size & 7 + + "cmp w4, #4 \n" // w4 >= 4 + "blt 18f \n" + "sub w4, w4, #4 \n" + "st1 {v0.16b, v1.16b}, [%0], #32 \n" + "18: \n" + + "cmp w4, #2 \n" // w4 >= 2 + "blt 19f \n" + "sub w4, w4, #2 \n" + "st1 {v0.16b}, [%0], #16 \n" + "19: \n" + + "cmp w4, #0 \n" // w4 > 0 + "beq 20f \n" + "st1 {v0.8b}, [%0], #8 \n" + "20: \n" + + : "=r"(outptr), // %0 + "=r"(ptr) // %1 + : "0"(outptr), + "1"(ptr), + "r"(w), // %4 + "r"(h), // %5 + "r"(left), // %6 + "r"(right), // %7 + "r"(top_size), // %8 + "r"(bottom_size), // %9 + "w"(v) // %10 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19"); +#else // __aarch64__ + asm volatile( + "vmov d0, %P10 \n" + "vmov d1, d0 \n" + "vmov q1, q0 \n" + "vmov q2, q0 \n" + "vmov q3, q0 \n" + + // fill top + "lsr r4, %8, #3 \n" // r4 = nn = top_size >> 3 + "cmp r4, #0 \n" + "beq 1f \n" + + "0: \n" + "vstm %0!, {d0-d7} \n" + "subs r4, r4, #1 \n" + "bne 0b \n" + + "1: \n" + + // fill top remain + "and r4, %8, #7 \n" // r4 = remain = top_size & 7 + + "cmp r4, #4 \n" // r4 >= 4 + "blt 2f \n" + "sub r4, r4, #4 \n" + "vst1.s8 {d0-d3}, [%0 :128]! \n" + "2: \n" + + "cmp r4, #2 \n" // r4 >= 2 + "blt 3f \n" + "sub r4, r4, #2 \n" + "vst1.s8 {d0-d1}, [%0 :128]! \n" + "3: \n" + + "cmp r4, #0 \n" // r4 > 0 + "beq 4f \n" + "vst1.s8 {d0}, [%0 :64]! \n" + "4: \n" + + // fill center h loop + "cmp %5, #0 \n" + "beq 15f \n" + "5: \n" + + // fill left + "mov r4, %6 \n" // r4 = left + "cmp r4, #0 \n" + "beq 7f \n" + + "6: \n" + "vst1.s8 {d0}, [%0 :64]! \n" + "subs r4, r4, #1 \n" + "bne 6b \n" + + "7: \n" + + // fill middle + "lsr r4, %4, #3 \n" // r4 = nn = w >> 3 + "cmp r4, #0 \n" + "beq 9f \n" + + "8: \n" + "pld [%1, #512] \n" + "vldm %1!, {d16-d23} \n" + "subs r4, r4, #1 \n" + "vstm %0!, {d16-d23} \n" + "bne 8b \n" + + "9: \n" + + "and r4, %4, #7 \n" // r4 = remain = w & 7 + + "cmp r4, #4 \n" // r4 >= 4 + "blt 10f \n" + "pld [%1, #256] \n" + "vld1.s8 {d16-d19}, [%1 :64]! \n" + "sub r4, r4, #4 \n" + "vst1.s8 {d16-d19}, [%0 :64]! \n" + "10: \n" + + "cmp r4, #2 \n" // r4 >= 2 + "blt 11f \n" + "pld [%1, #128] \n" + "vld1.s8 {d16-d17}, [%1 :64]! \n" + "sub r4, r4, #2 \n" + "vst1.s8 {d16-d17}, [%0 :64]! \n" + "11: \n" + + "cmp r4, #0 \n" // r4 > 0 + "beq 12f \n" + "pld [%1, #64] \n" + "vld1.s8 {d16}, [%1 :64]! \n" + "vst1.s8 {d16}, [%0 :64]! \n" + "12: \n" + + // fill right + "mov r4, %7 \n" // r4 = right + "cmp r4, #0 \n" + "beq 14f \n" + + "13: \n" + "subs r4, r4, #1 \n" + "vst1.s8 {d0}, [%0 :64]! \n" + "bne 13b \n" + "14: \n" + + "subs %5, %5, #1 \n" + "bne 5b \n" + + "15: \n" + + // fill bottom + "lsr r4, %9, #3 \n" // r4 = nn = bottom_size >> 3 + "cmp r4, #0 \n" + "beq 17f \n" + + "16: \n" + "vstm %0!, {d0-d7} \n" + "subs r4, r4, #1 \n" + "bne 16b \n" + "17: \n" + + // fill bottom remain + "and r4, %9, #7 \n" // r4 = remain = bottom_size & 7 + + "cmp r4, #4 \n" // r4 >= 4 + "blt 18f \n" + "sub r4, r4, #4 \n" + "vst1.s8 {d0-d3}, [%0 :64]! \n" + "18: \n" + + "cmp r4, #2 \n" // r4 >= 2 + "blt 19f \n" + "sub r4, r4, #2 \n" + "vst1.s8 {d0-d1}, [%0 :64]! \n" + "19: \n" + + "cmp r4, #0 \n" // r4 > 0 + "beq 20f \n" + "vst1.s8 {d0}, [%0 :64]! \n" + "20: \n" + + : "=r"(outptr), // %0 + "=r"(ptr) // %1 + : "0"(outptr), + "1"(ptr), + "r"(w), // %4 + "r"(h), // %5 + "r"(left), // %6 + "r"(right), // %7 + "r"(top_size), // %8 + "r"(bottom_size), // %9 + "w"(v) // %10 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); +#endif // __aarch64__ } static void padding_replicate_pack8_int8_neon(const Mat& src, Mat& dst, int top, int bottom, int left, int right)