From 76849cede4d68fde66646ecdc1faa070daeb6dac Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 17 Jul 2022 21:29:13 +0800 Subject: [PATCH] armv8.4 i8mm optimization for convolution gemm int8 (#4034) --- cmake/ncnn_add_layer.cmake | 14 +- src/CMakeLists.txt | 4 +- src/layer/arm/convolution_arm_i8mm.cpp | 69 + src/layer/arm/convolution_sgemm_int8.h | 1669 ++++++++++++----- .../arm/convolution_sgemm_pack1to4_int8.h | 1358 +++++++++++--- .../arm/convolution_sgemm_pack8to1_int8.h | 1179 +++++++++--- .../arm/convolution_sgemm_pack8to4_int8.h | 865 +++++++-- tests/test_convolution.cpp | 3 +- 8 files changed, 4036 insertions(+), 1125 deletions(-) create mode 100644 src/layer/arm/convolution_arm_i8mm.cpp diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 731536224..89d61823d 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -239,25 +239,25 @@ macro(ncnn_add_layer class) ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml") endif() if(NCNN_ARM84BF16) - ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+bf16") + ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16") endif() if(NCNN_ARM84I8MM) - ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+i8mm") + ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm") endif() if(NCNN_ARM86SVE) - ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+sve") + ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve") endif() if(NCNN_ARM86SVE2) - ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+sve2") + ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2") endif() if(NCNN_ARM86SVEBF16) - ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+sve+bf16") + ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16") endif() if(NCNN_ARM86SVEI8MM) - ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+sve+i8mm") + 
ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm") endif() if(NCNN_ARM86SVEF32MM) - ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+sve+f32mm") + ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm") endif() endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 428c56197..a5582fc8b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -414,7 +414,7 @@ endif() if(((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm64|aarch64)"))) if(NOT NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) - set(ARM_MARCH_FLAG "-march=armv8.6-a+sve") + set(ARM_MARCH_FLAG "-march=armv8.6-a+fp16+dotprod+sve") if(NCNN_ARM86SVE2) set(ARM_MARCH_FLAG "${ARM_MARCH_FLAG}+sve2") endif() @@ -428,7 +428,7 @@ if(((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (APPLE AND CMAKE_OSX_AR set(ARM_MARCH_FLAG "${ARM_MARCH_FLAG}+f32mm") endif() elseif(NOT NCNN_RUNTIME_CPU AND (NCNN_ARM84BF16 OR NCNN_ARM84I8MM)) - set(ARM_MARCH_FLAG "-march=armv8.4-a") + set(ARM_MARCH_FLAG "-march=armv8.4-a+fp16+dotprod") if(NCNN_ARM84BF16) set(ARM_MARCH_FLAG "${ARM_MARCH_FLAG}+bf16") endif() diff --git a/src/layer/arm/convolution_arm_i8mm.cpp b/src/layer/arm/convolution_arm_i8mm.cpp new file mode 100644 index 000000000..adbb31177 --- /dev/null +++ b/src/layer/arm/convolution_arm_i8mm.cpp @@ -0,0 +1,69 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "convolution_sgemm_int8.h" +#include "convolution_sgemm_pack1to4_int8.h" +#include "convolution_sgemm_pack8to1_int8.h" +#include "convolution_sgemm_pack8to4_int8.h" + +// pack1 +void im2col_sgemm_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + im2col_sgemm_int8_neon(bottom_im2col, top_blob, kernel, opt); +} + +void convolution_im2col_sgemm_transform_kernel_int8_neon_i8mm(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + convolution_im2col_sgemm_transform_kernel_int8_neon(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); +} + +// pack1to4 +void im2col_sgemm_pack1to4_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt); +} + +void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon_i8mm(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); +} + +// pack8to1 +void im2col_sgemm_pack8to1_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + im2col_sgemm_pack8to1_int8_neon(bottom_im2col, top_blob, kernel, opt); +} + +void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon_i8mm(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + 
convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); +} + +// pack8to4 +void im2col_sgemm_pack8to4_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + im2col_sgemm_pack8to4_int8_neon(bottom_im2col, top_blob, kernel, opt); +} + +void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon_i8mm(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); +} + +} // namespace ncnn diff --git a/src/layer/arm/convolution_sgemm_int8.h b/src/layer/arm/convolution_sgemm_int8.h index 85317f88b..f9e412394 100644 --- a/src/layer/arm/convolution_sgemm_int8.h +++ b/src/layer/arm/convolution_sgemm_int8.h @@ -12,19 +12,36 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
-#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void im2col_sgemm_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); +void convolution_im2col_sgemm_transform_kernel_int8_neon_i8mm(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD void im2col_sgemm_int8_neon_asimddp(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); void convolution_im2col_sgemm_transform_kernel_int8_neon_asimddp(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); #endif +#endif static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + im2col_sgemm_int8_neon_i8mm(bottom_im2col, top_blob, kernel, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { im2col_sgemm_int8_neon_asimddp(bottom_im2col, top_blob, kernel, opt); return; } +#endif #endif // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -42,9 +59,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons #if __ARM_FEATURE_DOTPROD if (inch >= 8) { - if (size >= 16) - tmp.create(16 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); - else if (size >= 8) + 
if (size >= 8) tmp.create(8 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); @@ -55,9 +70,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } else if (inch >= 4) { - if (size >= 16) - tmp.create(16 * maxk, inch / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); @@ -68,9 +81,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } else { - if (size >= 16) - tmp.create(16 * maxk, inch, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); @@ -140,17 +151,17 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } #endif // __ARM_NEON { -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - int nn_size = size >> 4; + int nn_size = size >> 3; int remain_size_start = 0; #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { - int i = remain_size_start + ii * 16; + int i = ii * 8; - signed char* tmpptr = tmp.channel(i / 16); + signed char* tmpptr = tmp.channel(i / 8); int q = 0; for (; q + 7 < inch; q 
+= 8) @@ -166,17 +177,26 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( - "ld1 {v0.16b}, [%0] \n" - "ld1 {v1.16b}, [%1] \n" - "ld1 {v2.16b}, [%2] \n" - "ld1 {v3.16b}, [%3] \n" - "ld1 {v4.16b}, [%4] \n" - "ld1 {v5.16b}, [%5] \n" - "ld1 {v6.16b}, [%6] \n" - "ld1 {v7.16b}, [%7] \n" - "st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%8], #64 \n" - "st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [%8], #64 \n" + "ld1 {v0.8b}, [%0] \n" + "ld1 {v1.8b}, [%1] \n" + "ld1 {v2.8b}, [%2] \n" + "ld1 {v3.8b}, [%3] \n" + "ld1 {v4.8b}, [%4] \n" + "ld1 {v5.8b}, [%5] \n" + "ld1 {v6.8b}, [%6] \n" + "ld1 {v7.8b}, [%7] \n" + "zip1 v8.8b, v0.8b, v4.8b \n" + "zip1 v9.8b, v1.8b, v5.8b \n" + "zip1 v10.8b, v2.8b, v6.8b \n" + "zip1 v11.8b, v3.8b, v7.8b \n" + "zip2 v0.8b, v0.8b, v4.8b \n" + "zip2 v1.8b, v1.8b, v5.8b \n" + "zip2 v2.8b, v2.8b, v6.8b \n" + "zip2 v3.8b, v3.8b, v7.8b \n" + "st4 {v8.8b, v9.8b, v10.8b, v11.8b}, [%8], #32 \n" + "st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [%8], #32 \n" : "=r"(img0), // %0 "=r"(img1), "=r"(img2), @@ -195,93 +215,8 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons "6"(img6), "7"(img7), "8"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); - img0 += size; - img1 += size; - img2 += size; - img3 += size; - img4 += size; - img5 += size; - img6 += size; - img7 += size; - } - } - for (; q + 3 < inch; q += 4) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; - const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; - const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - - for (int k = 0; k < maxk; k++) - { - asm volatile( - "ld1 {v0.16b}, [%0] \n" - "ld1 {v1.16b}, [%1] \n" - "ld1 {v2.16b}, [%2] \n" - "ld1 {v3.16b}, [%3] \n" - "st4 {v0.16b, v1.16b, v2.16b, 
v3.16b}, [%4], #64 \n" - : "=r"(img0), // %0 - "=r"(img1), - "=r"(img2), - "=r"(img3), - "=r"(tmpptr) // %4 - : "0"(img0), - "1"(img1), - "2"(img2), - "3"(img3), - "4"(tmpptr) - : "memory", "v0", "v1", "v2", "v3"); - img0 += size; - img1 += size; - img2 += size; - img3 += size; - } - } - for (; q < inch; q++) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - - for (int k = 0; k < maxk; k++) - { - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.16b}, [%0] \n" - "st1 {v0.16b}, [%1], #16 \n" - : "=r"(img0), // %0 - "=r"(tmpptr) // %1 - : "0"(img0), - "1"(tmpptr) - : "memory", "v0"); - img0 += size; - } - } - } - - remain_size_start += nn_size << 4; - nn_size = (size - remain_size_start) >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) - { - int i = remain_size_start + ii * 8; - - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - - int q = 0; - for (; q + 7 < inch; q += 8) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; - const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; - const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i; - const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i; - const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i; - const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i; - - for (int k = 0; k < maxk; k++) - { + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"); +#else // __ARM_FEATURE_MATMUL_INT8 asm volatile( "ld1 {v0.8b}, [%0] \n" "ld1 {v1.8b}, [%1] \n" @@ -312,6 +247,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons "7"(img7), "8"(tmpptr) : 
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#endif // __ARM_FEATURE_MATMUL_INT8 img0 += size; img1 += size; img2 += size; @@ -387,7 +323,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int i = remain_size_start + ii * 4; #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); #else signed char* tmpptr = tmp.channel(i / 4); #endif @@ -406,7 +342,47 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img4[0]; + tmpptr[5] = img5[0]; + tmpptr[6] = img6[0]; + tmpptr[7] = img7[0]; + tmpptr += 8; + + tmpptr[0] = img0[1]; + tmpptr[1] = img1[1]; + tmpptr[2] = img2[1]; + tmpptr[3] = img3[1]; + tmpptr[4] = img4[1]; + tmpptr[5] = img5[1]; + tmpptr[6] = img6[1]; + tmpptr[7] = img7[1]; + tmpptr += 8; + + tmpptr[0] = img0[2]; + tmpptr[1] = img1[2]; + tmpptr[2] = img2[2]; + tmpptr[3] = img3[2]; + tmpptr[4] = img4[2]; + tmpptr[5] = img5[2]; + tmpptr[6] = img6[2]; + tmpptr[7] = img7[2]; + tmpptr += 8; + + tmpptr[0] = img0[3]; + tmpptr[1] = img1[3]; + tmpptr[2] = img2[3]; + tmpptr[3] = img3[3]; + tmpptr[4] = img4[3]; + tmpptr[5] = img5[3]; + tmpptr[6] = img6[3]; + tmpptr[7] = img7[3]; + tmpptr += 8; +#elif __ARM_FEATURE_DOTPROD tmpptr[0] = img0[0]; tmpptr[1] = img1[0]; tmpptr[2] = img2[0]; @@ -486,7 +462,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons tmpptr[6] = img6[3]; tmpptr[7] = img7[3]; tmpptr += 8; -#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size; img1 += size; @@ -553,19 +529,19 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons remain_size_start += nn_size << 2; nn_size = (size - remain_size_start) >> 1; -#else // 
__ARM_NEON && __aarch64__ +#else // __aarch64__ int remain_size_start = 0; int nn_size = (size - remain_size_start) >> 1; -#endif // __ARM_NEON && __aarch64__ +#endif // __aarch64__ #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { int i = remain_size_start + ii * 2; -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); #endif @@ -588,7 +564,27 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img4[0]; + tmpptr[5] = img5[0]; + tmpptr[6] = img6[0]; + tmpptr[7] = img7[0]; + tmpptr += 8; + + tmpptr[0] = img0[1]; + tmpptr[1] = img1[1]; + tmpptr[2] = img2[1]; + tmpptr[3] = img3[1]; + tmpptr[4] = img4[1]; + tmpptr[5] = img5[1]; + tmpptr[6] = img6[1]; + tmpptr[7] = img7[1]; + tmpptr += 8; +#elif __ARM_FEATURE_DOTPROD tmpptr[0] = img0[0]; tmpptr[1] = img1[0]; tmpptr[2] = img2[0]; @@ -628,7 +624,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons tmpptr[6] = img6[1]; tmpptr[7] = img7[1]; tmpptr += 8; -#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size; img1 += size; @@ -687,9 +683,9 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons #pragma omp parallel for num_threads(opt.num_threads) for (int i = remain_size_start; i < size; i++) { -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); #else signed char* 
tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); #endif @@ -774,25 +770,29 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int remain_outch_start = 0; #if __ARM_NEON - nn_outch = outch >> 2; +#if __ARM_FEATURE_DOTPROD + nn_outch = outch / 8; + remain_outch_start = nn_outch * 8; #pragma omp parallel for num_threads(opt.num_threads) for (int pp = 0; pp < nn_outch; pp++) { - int p = pp * 4; + int p = pp * 8; int* outptr0 = top_blob.channel(p); int* outptr1 = top_blob.channel(p + 1); int* outptr2 = top_blob.channel(p + 2); int* outptr3 = top_blob.channel(p + 3); + int* outptr4 = top_blob.channel(p + 4); + int* outptr5 = top_blob.channel(p + 5); + int* outptr6 = top_blob.channel(p + 6); + int* outptr7 = top_blob.channel(p + 7); int i = 0; -#if __aarch64__ -#if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p / 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8); int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -816,95 +816,118 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons "eor v30.16b, v30.16b, v30.16b \n" "eor v31.16b, v31.16b, v31.16b \n" - "cmp %w4, #0 \n" + "cmp %w8, #0 \n" "beq 1f \n" - "ld1 {v8.16b}, [%8], #16 \n" // _w0123_l +#if __ARM_FEATURE_MATMUL_INT8 + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + + "0: \n" + + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%11], #64 \n" // _val0 _val1 _val2 _val3 + "ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%12], #64 \n" // _w01 _w23 _w45 _w67 + + "smmla v4.4s, v0.16b, v8.16b \n" + "smmla v17.4s, v0.16b, v9.16b \n" + 
"smmla v5.4s, v1.16b, v8.16b \n" + "smmla v19.4s, v1.16b, v9.16b \n" + "smmla v6.4s, v2.16b, v8.16b \n" + "smmla v21.4s, v2.16b, v9.16b \n" + "smmla v7.4s, v3.16b, v8.16b \n" + "smmla v23.4s, v3.16b, v9.16b \n" + + "subs %w8, %w8, #1 \n" + + "smmla v12.4s, v0.16b, v10.16b \n" + "smmla v25.4s, v0.16b, v11.16b \n" + "smmla v13.4s, v1.16b, v10.16b \n" + "smmla v27.4s, v1.16b, v11.16b \n" + "smmla v14.4s, v2.16b, v10.16b \n" + "smmla v29.4s, v2.16b, v11.16b \n" + "smmla v15.4s, v3.16b, v10.16b \n" + "smmla v31.4s, v3.16b, v11.16b \n" - "ld1 {v0.16b}, [%7], #16 \n" // _val0123_l + "bne 0b \n" + "trn1 v16.2d, v4.2d, v17.2d \n" + "trn2 v17.2d, v4.2d, v17.2d \n" + "trn1 v18.2d, v5.2d, v19.2d \n" + "trn2 v19.2d, v5.2d, v19.2d \n" + "trn1 v20.2d, v6.2d, v21.2d \n" + "trn2 v21.2d, v6.2d, v21.2d \n" + "trn1 v22.2d, v7.2d, v23.2d \n" + "trn2 v23.2d, v7.2d, v23.2d \n" + + "trn1 v24.2d, v12.2d, v25.2d \n" + "trn2 v25.2d, v12.2d, v25.2d \n" + "trn1 v26.2d, v13.2d, v27.2d \n" + "trn2 v27.2d, v13.2d, v27.2d \n" + "trn1 v28.2d, v14.2d, v29.2d \n" + "trn2 v29.2d, v14.2d, v29.2d \n" + "trn1 v30.2d, v15.2d, v31.2d \n" + "trn2 v31.2d, v15.2d, v31.2d \n" +#else // __ARM_FEATURE_MATMUL_INT8 "0: \n" - "ld1 {v1.16b}, [%7], #16 \n" // _val4567_l + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%11], #64 \n" // _val0123_l _val4567_l _val0123_h _val4567_h + "ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%12], #64 \n" // _w0123_l _w0123_h _w4567_l _w4567_h "sdot v16.4s, v8.16b, v0.4b[0] \n" "sdot v17.4s, v8.16b, v0.4b[1] \n" "sdot v18.4s, v8.16b, v0.4b[2] \n" "sdot v19.4s, v8.16b, v0.4b[3] \n" - - "ld1 {v2.16b}, [%7], #16 \n" // _val891011_l - "sdot v20.4s, v8.16b, v1.4b[0] \n" "sdot v21.4s, v8.16b, v1.4b[1] \n" "sdot v22.4s, v8.16b, v1.4b[2] \n" "sdot v23.4s, v8.16b, v1.4b[3] \n" - "ld1 {v3.16b}, [%7], #16 \n" // _val12131415_l - - "sdot v24.4s, v8.16b, v2.4b[0] \n" - "sdot v25.4s, v8.16b, v2.4b[1] \n" - - "ld1 {v9.16b}, [%8], #16 \n" // _w0123_h - - "sdot v26.4s, v8.16b, v2.4b[2] \n" - "sdot v27.4s, 
v8.16b, v2.4b[3] \n" - - "ld1 {v4.16b}, [%7], #16 \n" // _val0123_h - - "sdot v28.4s, v8.16b, v3.4b[0] \n" - "sdot v29.4s, v8.16b, v3.4b[1] \n" - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" - - "ld1 {v5.16b}, [%7], #16 \n" // _val4567_h - - "sdot v16.4s, v9.16b, v4.4b[0] \n" - "sdot v17.4s, v9.16b, v4.4b[1] \n" - "sdot v18.4s, v9.16b, v4.4b[2] \n" - "sdot v19.4s, v9.16b, v4.4b[3] \n" - - "ld1 {v6.16b}, [%7], #16 \n" // _val891011_h - - "sdot v20.4s, v9.16b, v5.4b[0] \n" - "sdot v21.4s, v9.16b, v5.4b[1] \n" - "sdot v22.4s, v9.16b, v5.4b[2] \n" - "sdot v23.4s, v9.16b, v5.4b[3] \n" - - "ld1 {v7.16b}, [%7], #16 \n" // _val12131415_h - - "sdot v24.4s, v9.16b, v6.4b[0] \n" - "sdot v25.4s, v9.16b, v6.4b[1] \n" - - "ld1 {v8.16b}, [%8], #16 \n" // _w0123_l - - "sdot v26.4s, v9.16b, v6.4b[2] \n" - "sdot v27.4s, v9.16b, v6.4b[3] \n" - - "ld1 {v0.16b}, [%7], #16 \n" // _val0123_l - - "sdot v28.4s, v9.16b, v7.4b[0] \n" - "sdot v29.4s, v9.16b, v7.4b[1] \n" - - "subs %w4, %w4, #1 \n" - - "sdot v30.4s, v9.16b, v7.4b[2] \n" - "sdot v31.4s, v9.16b, v7.4b[3] \n" + "sdot v16.4s, v9.16b, v2.4b[0] \n" + "sdot v17.4s, v9.16b, v2.4b[1] \n" + "sdot v18.4s, v9.16b, v2.4b[2] \n" + "sdot v19.4s, v9.16b, v2.4b[3] \n" + "sdot v20.4s, v9.16b, v3.4b[0] \n" + "sdot v21.4s, v9.16b, v3.4b[1] \n" + "sdot v22.4s, v9.16b, v3.4b[2] \n" + "sdot v23.4s, v9.16b, v3.4b[3] \n" + + "subs %w8, %w8, #1 \n" + + "sdot v24.4s, v10.16b, v0.4b[0] \n" + "sdot v25.4s, v10.16b, v0.4b[1] \n" + "sdot v26.4s, v10.16b, v0.4b[2] \n" + "sdot v27.4s, v10.16b, v0.4b[3] \n" + "sdot v28.4s, v10.16b, v1.4b[0] \n" + "sdot v29.4s, v10.16b, v1.4b[1] \n" + "sdot v30.4s, v10.16b, v1.4b[2] \n" + "sdot v31.4s, v10.16b, v1.4b[3] \n" + + "sdot v24.4s, v11.16b, v2.4b[0] \n" + "sdot v25.4s, v11.16b, v2.4b[1] \n" + "sdot v26.4s, v11.16b, v2.4b[2] \n" + "sdot v27.4s, v11.16b, v2.4b[3] \n" + "sdot v28.4s, v11.16b, v3.4b[0] \n" + "sdot v29.4s, v11.16b, v3.4b[1] \n" + "sdot v30.4s, v11.16b, v3.4b[2] \n" + "sdot 
v31.4s, v11.16b, v3.4b[3] \n" "bne 0b \n" - - "sub %7, %7, #16 \n" - "sub %8, %8, #16 \n" - +#endif // __ARM_FEATURE_MATMUL_INT8 "1: \n" - "cmp %w5, #0 \n" + "cmp %w9, #0 \n" "beq 3f \n" "2: \n" - "ld1 {v8.16b}, [%8], #16 \n" - - "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%7], #64 \n" + "ld1 {v0.16b, v1.16b}, [%11], #32 \n" // _val0123 _val4567 + "ld1 {v8.16b, v9.16b}, [%12], #32 \n" // _w0 _w1 "sdot v16.4s, v8.16b, v0.4b[0] \n" "sdot v17.4s, v8.16b, v0.4b[1] \n" @@ -914,157 +937,158 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons "sdot v21.4s, v8.16b, v1.4b[1] \n" "sdot v22.4s, v8.16b, v1.4b[2] \n" "sdot v23.4s, v8.16b, v1.4b[3] \n" - "sdot v24.4s, v8.16b, v2.4b[0] \n" - "sdot v25.4s, v8.16b, v2.4b[1] \n" - "sdot v26.4s, v8.16b, v2.4b[2] \n" - "sdot v27.4s, v8.16b, v2.4b[3] \n" - "sdot v28.4s, v8.16b, v3.4b[0] \n" - "sdot v29.4s, v8.16b, v3.4b[1] \n" - "subs %w5, %w5, #1 \n" + "subs %w9, %w9, #1 \n" - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" + "sdot v24.4s, v9.16b, v0.4b[0] \n" + "sdot v25.4s, v9.16b, v0.4b[1] \n" + "sdot v26.4s, v9.16b, v0.4b[2] \n" + "sdot v27.4s, v9.16b, v0.4b[3] \n" + "sdot v28.4s, v9.16b, v1.4b[0] \n" + "sdot v29.4s, v9.16b, v1.4b[1] \n" + "sdot v30.4s, v9.16b, v1.4b[2] \n" + "sdot v31.4s, v9.16b, v1.4b[3] \n" "bne 2b \n" "3: \n" - "lsr w4, %w6, #2 \n" // w4 = nn1 >> 2 + "lsr w4, %w10, #2 \n" // w4 = nn1 >> 2 "cmp w4, #0 \n" "beq 5f \n" "4: \n" - "ld1 {v8.8b, v9.8b}, [%8], #16 \n" - - "ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%7], #64 \n" - - "uzp1 v10.8b, v8.8b, v9.8b \n" - "uzp2 v11.8b, v8.8b, v9.8b \n" - - "uzp1 v4.16b, v0.16b, v1.16b \n" - "uzp2 v5.16b, v0.16b, v1.16b \n" - "uzp1 v6.16b, v2.16b, v3.16b \n" - "uzp2 v7.16b, v2.16b, v3.16b \n" - - "uzp1 v8.8b, v10.8b, v11.8b \n" - "uzp2 v9.8b, v10.8b, v11.8b \n" + "ld2 {v0.4s, v1.4s}, [%11], #32 \n" + "ld2 {v8.4s, v9.4s}, [%12], #32 \n" + + "uzp1 v2.16b, v0.16b, v1.16b \n" + "uzp2 v3.16b, v0.16b, v1.16b \n" + "uzp1 v0.16b, v2.16b, 
v3.16b \n" + "uzp2 v1.16b, v2.16b, v3.16b \n" + "uzp1 v2.4s, v0.4s, v1.4s \n" // _val0123 + "uzp2 v3.4s, v0.4s, v1.4s \n" // _val4567 + + "uzp1 v10.16b, v8.16b, v9.16b \n" + "uzp2 v11.16b, v8.16b, v9.16b \n" + "uzp1 v8.16b, v10.16b, v11.16b \n" + "uzp2 v9.16b, v10.16b, v11.16b \n" + "uzp1 v10.4s, v8.4s, v9.4s \n" // _w0123f + "uzp2 v11.4s, v8.4s, v9.4s \n" // _w4567f + + "sdot v16.4s, v10.16b, v2.4b[0] \n" + "sdot v17.4s, v10.16b, v2.4b[1] \n" + "sdot v18.4s, v10.16b, v2.4b[2] \n" + "sdot v19.4s, v10.16b, v2.4b[3] \n" + "sdot v20.4s, v10.16b, v3.4b[0] \n" + "sdot v21.4s, v10.16b, v3.4b[1] \n" + "sdot v22.4s, v10.16b, v3.4b[2] \n" + "sdot v23.4s, v10.16b, v3.4b[3] \n" - "uzp1 v0.16b, v4.16b, v5.16b \n" // 0 1 4 5 - "uzp2 v1.16b, v4.16b, v5.16b \n" // 8 9 c d - - "mov v8.d[1], v9.d[0] \n" // _w - - "uzp1 v2.16b, v6.16b, v7.16b \n" // 2 3 6 7 - "uzp2 v3.16b, v6.16b, v7.16b \n" // a b e f + "subs w4, w4, #1 \n" - "sdot v16.4s, v8.16b, v0.4b[0] \n" - "sdot v17.4s, v8.16b, v0.4b[1] \n" - "sdot v18.4s, v8.16b, v2.4b[0] \n" - "sdot v19.4s, v8.16b, v2.4b[1] \n" - "sdot v20.4s, v8.16b, v0.4b[2] \n" - "sdot v21.4s, v8.16b, v0.4b[3] \n" - "sdot v22.4s, v8.16b, v2.4b[2] \n" - "sdot v23.4s, v8.16b, v2.4b[3] \n" - "sdot v24.4s, v8.16b, v1.4b[0] \n" - "sdot v25.4s, v8.16b, v1.4b[1] \n" - "sdot v26.4s, v8.16b, v3.4b[0] \n" - "sdot v27.4s, v8.16b, v3.4b[1] \n" - "sdot v28.4s, v8.16b, v1.4b[2] \n" - "sdot v29.4s, v8.16b, v1.4b[3] \n" - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" + "sdot v24.4s, v11.16b, v2.4b[0] \n" + "sdot v25.4s, v11.16b, v2.4b[1] \n" + "sdot v26.4s, v11.16b, v2.4b[2] \n" + "sdot v27.4s, v11.16b, v2.4b[3] \n" + "sdot v28.4s, v11.16b, v3.4b[0] \n" + "sdot v29.4s, v11.16b, v3.4b[1] \n" + "sdot v30.4s, v11.16b, v3.4b[2] \n" + "sdot v31.4s, v11.16b, v3.4b[3] \n" - "subs w4, w4, #1 \n" "bne 4b \n" "5: \n" - "and w4, %w6, #3 \n" // w4 = remain = nn1 & 3 - "cmp w4, #0 \n" // w4 > 0 + "and w4, %w10, #3 \n" // w4 = remain = nn1 & 3 + "cmp w4, #0 
\n" // w4 > 0 "beq 7f \n" "6: \n" - "ld1 {v1.8b}, [%8] \n" - "ld1 {v0.16b}, [%7] \n" + "ld1 {v0.8b}, [%11], #8 \n" + "ld1 {v1.8b}, [%12], #8 \n" + + "sshll v0.8h, v0.8b, #0 \n" "sshll v1.8h, v1.8b, #0 \n" - "sshll v2.8h, v0.8b, #0 \n" - "sshll2 v3.8h, v0.16b, #0 \n" - - "smlal v16.4s, v1.4h, v2.h[0] \n" - "smlal v17.4s, v1.4h, v2.h[1] \n" - "smlal v18.4s, v1.4h, v2.h[2] \n" - "smlal v19.4s, v1.4h, v2.h[3] \n" - "smlal v20.4s, v1.4h, v2.h[4] \n" - "smlal v21.4s, v1.4h, v2.h[5] \n" - "smlal v22.4s, v1.4h, v2.h[6] \n" - "smlal v23.4s, v1.4h, v2.h[7] \n" - "smlal v24.4s, v1.4h, v3.h[0] \n" - "smlal v25.4s, v1.4h, v3.h[1] \n" - "smlal v26.4s, v1.4h, v3.h[2] \n" - "smlal v27.4s, v1.4h, v3.h[3] \n" - "smlal v28.4s, v1.4h, v3.h[4] \n" - "smlal v29.4s, v1.4h, v3.h[5] \n" - "smlal v30.4s, v1.4h, v3.h[6] \n" - "smlal v31.4s, v1.4h, v3.h[7] \n" - - "add %7, %7, #16 \n" - "add %8, %8, #4 \n" + + "smlal v16.4s, v1.4h, v0.h[0] \n" + "smlal v17.4s, v1.4h, v0.h[1] \n" + "smlal v18.4s, v1.4h, v0.h[2] \n" + "smlal v19.4s, v1.4h, v0.h[3] \n" + "smlal v20.4s, v1.4h, v0.h[4] \n" + "smlal v21.4s, v1.4h, v0.h[5] \n" + "smlal v22.4s, v1.4h, v0.h[6] \n" + "smlal v23.4s, v1.4h, v0.h[7] \n" "subs w4, w4, #1 \n" + + "smlal2 v24.4s, v1.8h, v0.h[0] \n" + "smlal2 v25.4s, v1.8h, v0.h[1] \n" + "smlal2 v26.4s, v1.8h, v0.h[2] \n" + "smlal2 v27.4s, v1.8h, v0.h[3] \n" + "smlal2 v28.4s, v1.8h, v0.h[4] \n" + "smlal2 v29.4s, v1.8h, v0.h[5] \n" + "smlal2 v30.4s, v1.8h, v0.h[6] \n" + "smlal2 v31.4s, v1.8h, v0.h[7] \n" + "bne 6b \n" "7: \n" - // transpose 4x16 - "trn1 v0.4s, v16.4s, v17.4s \n" - "trn2 v1.4s, v16.4s, v17.4s \n" - "trn1 v2.4s, v18.4s, v19.4s \n" - "trn2 v3.4s, v18.4s, v19.4s \n" - "trn1 v4.4s, v20.4s, v21.4s \n" - "trn2 v5.4s, v20.4s, v21.4s \n" - "trn1 v6.4s, v22.4s, v23.4s \n" - "trn2 v7.4s, v22.4s, v23.4s \n" - "trn1 v8.4s, v24.4s, v25.4s \n" - "trn2 v9.4s, v24.4s, v25.4s \n" - "trn1 v10.4s, v26.4s, v27.4s \n" - "trn2 v11.4s, v26.4s, v27.4s \n" - "trn1 v12.4s, v28.4s, v29.4s \n" - "trn2 
v13.4s, v28.4s, v29.4s \n" - "trn1 v14.4s, v30.4s, v31.4s \n" - "trn2 v15.4s, v30.4s, v31.4s \n" - - "trn1 v16.2d, v0.2d, v2.2d \n" - "trn2 v24.2d, v0.2d, v2.2d \n" - "trn1 v20.2d, v1.2d, v3.2d \n" - "trn2 v28.2d, v1.2d, v3.2d \n" - - "trn1 v17.2d, v4.2d, v6.2d \n" - "trn2 v25.2d, v4.2d, v6.2d \n" - "trn1 v21.2d, v5.2d, v7.2d \n" - "trn2 v29.2d, v5.2d, v7.2d \n" - - "trn1 v18.2d, v8.2d, v10.2d \n" - "trn2 v26.2d, v8.2d, v10.2d \n" - "trn1 v22.2d, v9.2d, v11.2d \n" - "trn2 v30.2d, v9.2d, v11.2d \n" - - "trn1 v19.2d, v12.2d, v14.2d \n" - "trn2 v27.2d, v12.2d, v14.2d \n" - "trn1 v23.2d, v13.2d, v15.2d \n" - "trn2 v31.2d, v13.2d, v15.2d \n" - - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%1], #64 \n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%3], #64 \n" + "trn1 v0.4s, v16.4s, v17.4s \n" + "trn2 v1.4s, v16.4s, v17.4s \n" + "trn1 v2.4s, v18.4s, v19.4s \n" + "trn2 v3.4s, v18.4s, v19.4s \n" + "trn1 v4.4s, v20.4s, v21.4s \n" + "trn2 v5.4s, v20.4s, v21.4s \n" + "trn1 v6.4s, v22.4s, v23.4s \n" + "trn2 v7.4s, v22.4s, v23.4s \n" + + "trn1 v16.2d, v0.2d, v2.2d \n" + "trn1 v18.2d, v1.2d, v3.2d \n" + "trn2 v20.2d, v0.2d, v2.2d \n" + "trn2 v22.2d, v1.2d, v3.2d \n" + "trn1 v17.2d, v4.2d, v6.2d \n" + "trn1 v19.2d, v5.2d, v7.2d \n" + "trn2 v21.2d, v4.2d, v6.2d \n" + "trn2 v23.2d, v5.2d, v7.2d \n" + + "trn1 v0.4s, v24.4s, v25.4s \n" + "trn2 v1.4s, v24.4s, v25.4s \n" + "trn1 v2.4s, v26.4s, v27.4s \n" + "trn2 v3.4s, v26.4s, v27.4s \n" + "trn1 v4.4s, v28.4s, v29.4s \n" + "trn2 v5.4s, v28.4s, v29.4s \n" + "trn1 v6.4s, v30.4s, v31.4s \n" + "trn2 v7.4s, v30.4s, v31.4s \n" + + "trn1 v24.2d, v0.2d, v2.2d \n" + "trn1 v26.2d, v1.2d, v3.2d \n" + "trn2 v28.2d, v0.2d, v2.2d \n" + "trn2 v30.2d, v1.2d, v3.2d \n" + "trn1 v25.2d, v4.2d, v6.2d \n" + "trn1 v27.2d, v5.2d, v7.2d \n" + "trn2 v29.2d, v4.2d, v6.2d \n" + "trn2 v31.2d, v5.2d, v7.2d \n" + + "st1 {v16.4s, v17.4s}, [%0], #32 \n" + "st1 
{v18.4s, v19.4s}, [%1], #32 \n" + "st1 {v20.4s, v21.4s}, [%2], #32 \n" + "st1 {v22.4s, v23.4s}, [%3], #32 \n" + "st1 {v24.4s, v25.4s}, [%4], #32 \n" + "st1 {v26.4s, v27.4s}, [%5], #32 \n" + "st1 {v28.4s, v29.4s}, [%6], #32 \n" + "st1 {v30.4s, v31.4s}, [%7], #32 \n" : "=r"(outptr0), "=r"(outptr1), "=r"(outptr2), "=r"(outptr3), + "=r"(outptr4), + "=r"(outptr5), + "=r"(outptr6), + "=r"(outptr7), "=r"(nn), "=r"(nn4), "=r"(nn1), @@ -1074,17 +1098,566 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons "1"(outptr1), "2"(outptr2), "3"(outptr3), - "4"(nn), - "5"(nn4), - "6"(nn1), - "7"(tmpptr), - "8"(kptr0) + "4"(outptr4), + "5"(outptr5), + "6"(outptr6), + "7"(outptr7), + "8"(nn), + "9"(nn4), + "10"(nn1), + "11"(tmpptr), + "12"(kptr0) : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); } + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = 
vmmlaq_s32(_sum3, _val1, _w23); + + _sum4 = vmmlaq_s32(_sum4, _val0, _w45); + _sum5 = vmmlaq_s32(_sum5, _val0, _w67); + _sum6 = vmmlaq_s32(_sum6, _val1, _w45); + _sum7 = vmmlaq_s32(_sum7, _val1, _w67); + + tmpptr += 32; + kptr0 += 64; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum4x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum5x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum6x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + int32x4_t _sum7x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; + _sum4 = _sum4x; + _sum5 = _sum5x; + _sum6 = _sum6x; + _sum7 = _sum7x; +#else // __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0123_l = vld1q_s8(tmpptr); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, 
_val0123_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3); + + _sum4 = vdotq_laneq_s32(_sum4, _w4567_l, _val0123_l, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_l, _val0123_l, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_l, _val0123_l, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_l, _val0123_l, 3); + _sum4 = vdotq_laneq_s32(_sum4, _w4567_h, _val0123_h, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_h, _val0123_h, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_h, _val0123_h, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_h, _val0123_h, 3); + + tmpptr += 32; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + for (int j = 0; j < nn4; j++) + { + int8x16_t _val0123 = vld1q_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + + _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0, _val0123, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0, _val0123, 3); + + _sum4 = vdotq_laneq_s32(_sum4, _w1, _val0123, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w1, _val0123, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w1, _val0123, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w1, _val0123, 3); + + tmpptr += 16; + kptr0 += 32; + } + + int j = 0; + for (; j + 3 < nn1; j += 4) + { + // 0123 0123 0123 0123 -> 0000111122223333 + int8x16_t _val = vld1q_s8(tmpptr); + + int8x8x2_t _val01 = vuzp_s8(vget_low_s8(_val), vget_high_s8(_val)); + int8x8x2_t _val0123 = vuzp_s8(_val01.val[0], _val01.val[1]); + int8x16_t _val0123f = vcombine_s8(_val0123.val[0], _val0123.val[1]); + + // 0123 4567 0123 4567 0123 4567 0123 4567 -> 0000111122223333 + int32x4x2_t _w = vld2q_s32((const int*)kptr0); + + int8x16_t _w0 = vreinterpretq_s8_s32(_w.val[0]); + int8x16_t _w1 = vreinterpretq_s8_s32(_w.val[1]); + + int8x8x2_t _w01 = vuzp_s8(vget_low_s8(_w0), vget_high_s8(_w0)); + int8x8x2_t _w0123 = vuzp_s8(_w01.val[0], _w01.val[1]); + int8x16_t _w0123f = 
vcombine_s8(_w0123.val[0], _w0123.val[1]); + + int8x8x2_t _w45 = vuzp_s8(vget_low_s8(_w1), vget_high_s8(_w1)); + int8x8x2_t _w4567 = vuzp_s8(_w45.val[0], _w45.val[1]); + int8x16_t _w4567f = vcombine_s8(_w4567.val[0], _w4567.val[1]); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123f, _val0123f, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123f, _val0123f, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123f, _val0123f, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123f, _val0123f, 3); + + _sum4 = vdotq_laneq_s32(_sum4, _w4567f, _val0123f, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567f, _val0123f, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567f, _val0123f, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567f, _val0123f, 3); + + tmpptr += 16; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val0 = vdup_n_s16(tmpptr[0]); + int16x4_t _val1 = vdup_n_s16(tmpptr[1]); + int16x4_t _val2 = vdup_n_s16(tmpptr[2]); + int16x4_t _val3 = vdup_n_s16(tmpptr[3]); + + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); + + _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w01)); + _sum1 = vmlal_s16(_sum1, _val1, vget_low_s16(_w01)); + _sum2 = vmlal_s16(_sum2, _val2, vget_low_s16(_w01)); + _sum3 = vmlal_s16(_sum3, _val3, vget_low_s16(_w01)); + + _sum4 = vmlal_s16(_sum4, _val0, vget_high_s16(_w01)); + _sum5 = vmlal_s16(_sum5, _val1, vget_high_s16(_w01)); + _sum6 = vmlal_s16(_sum6, _val2, vget_high_s16(_w01)); + _sum7 = vmlal_s16(_sum7, _val3, vget_high_s16(_w01)); + + tmpptr += 4; + kptr0 += 8; + } + + // transpose 4x4 + int32x4_t _sum01_0 = vtrn1q_s32(_sum0, _sum1); + int32x4_t _sum01_1 = vtrn2q_s32(_sum0, _sum1); + int32x4_t _sum23_0 = vtrn1q_s32(_sum2, _sum3); + int32x4_t _sum23_1 = vtrn2q_s32(_sum2, _sum3); + int32x4_t _sum45_0 = vtrn1q_s32(_sum4, _sum5); + int32x4_t _sum45_1 = vtrn2q_s32(_sum4, _sum5); + int32x4_t _sum67_0 = vtrn1q_s32(_sum6, _sum7); + int32x4_t _sum67_1 = vtrn2q_s32(_sum6, _sum7); + _sum0 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum01_0), vreinterpretq_s64_s32(_sum23_0))); + _sum1 = 
vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum01_1), vreinterpretq_s64_s32(_sum23_1))); + _sum2 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum01_0), vreinterpretq_s64_s32(_sum23_0))); + _sum3 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum01_1), vreinterpretq_s64_s32(_sum23_1))); + _sum4 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum45_0), vreinterpretq_s64_s32(_sum67_0))); + _sum5 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum45_1), vreinterpretq_s64_s32(_sum67_1))); + _sum6 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum45_0), vreinterpretq_s64_s32(_sum67_0))); + _sum7 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum45_1), vreinterpretq_s64_s32(_sum67_1))); + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr1, _sum1); + vst1q_s32(outptr2, _sum2); + vst1q_s32(outptr3, _sum3); + vst1q_s32(outptr4, _sum4); + vst1q_s32(outptr5, _sum5); + vst1q_s32(outptr6, _sum6); + vst1q_s32(outptr7, _sum7); + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + outptr4 += 4; + outptr5 += 4; + outptr6 += 4; + outptr7 += 4; + } + for (; i + 1 < size; i += 2) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + _sum2 = vmmlaq_s32(_sum2, _val, _w45); + _sum3 = vmmlaq_s32(_sum3, _val, _w67); + + tmpptr += 16; + kptr0 
+= 64; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; +#else // __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val01_l_h = vld1q_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val01_l_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val01_l_h, 1); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val01_l_h, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val01_l_h, 3); + + _sum2 = vdotq_laneq_s32(_sum2, _w4567_l, _val01_l_h, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_l, _val01_l_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w4567_h, _val01_l_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_h, _val01_l_h, 3); + + tmpptr += 16; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + if (nn4 > 0) + { + int j = 0; + for (; j + 1 < nn4; j += 2) + { + int8x16_t _val0123 = vld1q_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int8x16_t _w2 = vld1q_s8(kptr0 + 32); + int8x16_t _w3 = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w1, _val0123, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w1, _val0123, 1); + + _sum0 = vdotq_laneq_s32(_sum0, _w2, _val0123, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w2, _val0123, 3); + _sum2 = 
vdotq_laneq_s32(_sum2, _w3, _val0123, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _val0123, 3); + + tmpptr += 16; + kptr0 += 64; + } + for (; j < nn4; j++) + { + int8x8_t _val01 = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + + _sum0 = vdotq_lane_s32(_sum0, _w0, _val01, 0); + _sum1 = vdotq_lane_s32(_sum1, _w0, _val01, 1); + _sum2 = vdotq_lane_s32(_sum2, _w1, _val01, 0); + _sum3 = vdotq_lane_s32(_sum3, _w1, _val01, 1); + + tmpptr += 8; + kptr0 += 32; + } + } + + int j = 0; + for (; j + 3 < nn1; j += 4) + { + int16x8_t _val01234567 = vmovl_s8(vld1_s8(tmpptr)); + + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int16x8_t _w0l = vmovl_s8(vget_low_s8(_w0)); + int16x8_t _w0h = vmovl_s8(vget_high_s8(_w0)); + int16x8_t _w1l = vmovl_s8(vget_low_s8(_w1)); + int16x8_t _w1h = vmovl_s8(vget_high_s8(_w1)); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w0l), _val01234567, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w0l), _val01234567, 1); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w0l), _val01234567, 0); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w0l), _val01234567, 1); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w0h), _val01234567, 2); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w0h), _val01234567, 3); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w0h), _val01234567, 2); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w0h), _val01234567, 3); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w1l), _val01234567, 4); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w1l), _val01234567, 5); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w1l), _val01234567, 4); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w1l), _val01234567, 5); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w1h), _val01234567, 6); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w1h), _val01234567, 7); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w1h), _val01234567, 6); + _sum3 = vmlal_laneq_s16(_sum3, 
vget_high_s16(_w1h), _val01234567, 7); + + tmpptr += 8; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val0 = vdup_n_s16(tmpptr[0]); + int16x4_t _val1 = vdup_n_s16(tmpptr[1]); + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); + + _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w01)); + _sum1 = vmlal_s16(_sum1, _val1, vget_low_s16(_w01)); + _sum2 = vmlal_s16(_sum2, _val0, vget_high_s16(_w01)); + _sum3 = vmlal_s16(_sum3, _val1, vget_high_s16(_w01)); + + tmpptr += 2; + kptr0 += 8; + } + + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _sum23 = vzipq_s32(_sum2, _sum3); + + vst1_s32(outptr0, vget_low_s32(_sum01.val[0])); + vst1_s32(outptr1, vget_high_s32(_sum01.val[0])); + vst1_s32(outptr2, vget_low_s32(_sum01.val[1])); + vst1_s32(outptr3, vget_high_s32(_sum01.val[1])); + vst1_s32(outptr4, vget_low_s32(_sum23.val[0])); + vst1_s32(outptr5, vget_high_s32(_sum23.val[0])); + vst1_s32(outptr6, vget_low_s32(_sum23.val[1])); + vst1_s32(outptr7, vget_high_s32(_sum23.val[1])); + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; + outptr4 += 2; + outptr5 += 2; + outptr6 += 2; + outptr7 += 2; + } + for (; i < size; i++) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + _sum45 = vdotq_s32(_sum45, _val, _w45); + _sum67 = 
vdotq_s32(_sum67, _val, _w67); + + tmpptr += 8; + kptr0 += 64; + } + + int32x4_t _sum0 = vpaddq_s32(_sum01, _sum23); + int32x4_t _sum1 = vpaddq_s32(_sum45, _sum67); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); + _sum1 = vdotq_lane_s32(_sum1, _w4567_l, _val0_l_h, 0); + _sum1 = vdotq_lane_s32(_sum1, _w4567_h, _val0_l_h, 1); + + tmpptr += 8; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + if (nn4 > 0) + { + int j = 0; + for (; j + 1 < nn4; j += 2) + { + int8x8_t _val01 = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int8x16_t _w2 = vld1q_s8(kptr0 + 32); + int8x16_t _w3 = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_lane_s32(_sum0, _w0, _val01, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _val01, 0); + _sum0 = vdotq_lane_s32(_sum0, _w2, _val01, 1); + _sum1 = vdotq_lane_s32(_sum1, _w3, _val01, 1); + + tmpptr += 8; + kptr0 += 64; + } + for (; j < nn4; j++) + { + int8x8_t _val_xxx = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + + _sum0 = vdotq_lane_s32(_sum0, _w0, _val_xxx, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _val_xxx, 0); + + tmpptr += 4; + kptr0 += 32; + } + } + + int j = 0; + for (; j + 3 < nn1; j += 4) + { + int16x4_t _val0123 = vget_low_s16(vmovl_s8(vld1_s8(tmpptr))); + + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int16x8_t _w0l = vmovl_s8(vget_low_s8(_w0)); + int16x8_t _w0h = vmovl_s8(vget_high_s8(_w0)); + int16x8_t _w1l = vmovl_s8(vget_low_s8(_w1)); + int16x8_t _w1h = vmovl_s8(vget_high_s8(_w1)); + + _sum0 = 
vmlal_lane_s16(_sum0, vget_low_s16(_w0l), _val0123, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0l), _val0123, 0); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0h), _val0123, 1); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0h), _val0123, 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1l), _val0123, 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1l), _val0123, 2); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1h), _val0123, 3); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1h), _val0123, 3); + + tmpptr += 4; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val = vdup_n_s16(tmpptr[0]); + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); + + _sum0 = vmlal_s16(_sum0, _val, vget_low_s16(_w01)); + _sum1 = vmlal_s16(_sum1, _val, vget_high_s16(_w01)); + + tmpptr += 1; + kptr0 += 8; + } + + outptr0[0] = vgetq_lane_s32(_sum0, 0); + outptr1[0] = vgetq_lane_s32(_sum0, 1); + outptr2[0] = vgetq_lane_s32(_sum0, 2); + outptr3[0] = vgetq_lane_s32(_sum0, 3); + outptr4[0] = vgetq_lane_s32(_sum1, 0); + outptr5[0] = vgetq_lane_s32(_sum1, 1); + outptr6[0] = vgetq_lane_s32(_sum1, 2); + outptr7[0] = vgetq_lane_s32(_sum1, 3); + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; + outptr4 += 1; + outptr5 += 1; + outptr6 += 1; + outptr7 += 1; + } + } +#endif // __ARM_FEATURE_DOTPROD + + nn_outch = (outch - remain_outch_start) >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outch; pp++) + { + int p = remain_outch_start + pp * 4; + + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); + int* outptr2 = top_blob.channel(p + 2); + int* outptr3 = top_blob.channel(p + 3); + + int i = 0; +#if __aarch64__ +#if __ARM_FEATURE_DOTPROD for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p / 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); 
int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1099,6 +1672,48 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int32x4_t _sum6 = vdupq_n_s32(0); int32x4_t _sum7 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _val2 = vld1q_s8(tmpptr + 32); + int8x16_t _val3 = vld1q_s8(tmpptr + 48); + + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + _sum4 = vmmlaq_s32(_sum4, _val2, _w01); + _sum5 = vmmlaq_s32(_sum5, _val2, _w23); + _sum6 = vmmlaq_s32(_sum6, _val3, _w01); + _sum7 = vmmlaq_s32(_sum7, _val3, _w23); + + tmpptr += 64; + kptr0 += 32; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum4x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum5x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum6x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + int32x4_t _sum7x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; + _sum4 = _sum4x; + _sum5 = _sum5x; + _sum6 = _sum6x; + _sum7 = _sum7x; 
+#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val0123_l = vld1q_s8(tmpptr); @@ -1132,6 +1747,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons tmpptr += 64; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn4; j++) { @@ -1242,11 +1858,12 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1257,6 +1874,33 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int32x4_t _sum2 = vdupq_n_s32(0); int32x4_t _sum3 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + tmpptr += 32; + kptr0 += 32; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = 
_sum3x; +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val0123_l = vld1q_s8(tmpptr); @@ -1278,6 +1922,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons tmpptr += 32; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn4; j++) { @@ -1770,14 +2415,16 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1786,6 +2433,26 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int32x4_t _sum00 = vdupq_n_s32(0); int32x4_t _sum10 = vdupq_n_s32(0); #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum00 = vmmlaq_s32(_sum00, _val, _w01); + _sum10 = vmmlaq_s32(_sum10, _val, _w23); + + tmpptr += 16; + kptr0 += 32; + } + + int32x4_t _sum00x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum00), vreinterpretq_s64_s32(_sum10))); + int32x4_t _sum10x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum00), vreinterpretq_s64_s32(_sum10))); + + _sum00 = _sum00x; + _sum10 = _sum10x; +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val01_l_h = vld1q_s8(tmpptr); @@ -1802,6 +2469,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& 
top_blob, cons tmpptr += 16; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 if (nn4 > 0) { @@ -2335,14 +3003,16 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -2350,6 +3020,26 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int32x4_t _sum0 = vdupq_n_s32(0); #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum23 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum0 = vdotq_s32(_sum0, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + + tmpptr += 8; + kptr0 += 32; + } + + _sum0 = vpaddq_s32(_sum0, _sum23); +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x8_t _val0_l_h = vld1_s8(tmpptr); @@ -2365,6 +3055,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons tmpptr += 8; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 if (nn4 > 0) { @@ -2753,12 +3444,12 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int* outptr0 = top_blob.channel(p); int i = 0; -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - 
const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p / 4 + p % 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -2766,102 +3457,42 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int8x16_t _val0123_l = vld1q_s8(tmpptr); - int8x16_t _val4567_l = vld1q_s8(tmpptr + 16); - int8x16_t _val89ab_l = vld1q_s8(tmpptr + 32); - int8x16_t _valcdef_l = vld1q_s8(tmpptr + 48); - int8x16_t _val0123_h = vld1q_s8(tmpptr + 64); - int8x16_t _val4567_h = vld1q_s8(tmpptr + 80); - int8x16_t _val89ab_h = vld1q_s8(tmpptr + 96); - int8x16_t _valcdef_h = vld1q_s8(tmpptr + 112); - int8x8_t _w_lh = vld1_s8(kptr0); - - _sum0 = vdotq_lane_s32(_sum0, _val0123_l, _w_lh, 0); - _sum1 = vdotq_lane_s32(_sum1, _val4567_l, _w_lh, 0); - _sum2 = vdotq_lane_s32(_sum2, _val89ab_l, _w_lh, 0); - _sum3 = vdotq_lane_s32(_sum3, _valcdef_l, _w_lh, 0); - _sum0 = vdotq_lane_s32(_sum0, _val0123_h, _w_lh, 1); - _sum1 = vdotq_lane_s32(_sum1, _val4567_h, _w_lh, 1); - _sum2 = vdotq_lane_s32(_sum2, _val89ab_h, _w_lh, 1); - _sum3 = vdotq_lane_s32(_sum3, _valcdef_h, _w_lh, 1); - - tmpptr += 128; - kptr0 += 8; - } - - if (nn4 > 0) + if (nn > 0) { - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _s0 = vdup_n_s32(0); + int32x2_t _s1 = vdup_n_s32(0); + int32x2_t _s2 = vdup_n_s32(0); + int32x2_t _s3 = vdup_n_s32(0); + int32x2_t _s4 = vdup_n_s32(0); + int32x2_t _s5 = vdup_n_s32(0); + int32x2_t _s6 = vdup_n_s32(0); + int32x2_t _s7 = vdup_n_s32(0); - for (int j = 0; j < nn4; j++) + for (int j = 0; j < 
nn; j++) { int8x16_t _val0 = vld1q_s8(tmpptr); int8x16_t _val1 = vld1q_s8(tmpptr + 16); int8x16_t _val2 = vld1q_s8(tmpptr + 32); int8x16_t _val3 = vld1q_s8(tmpptr + 48); + int8x8_t _w = vld1_s8(kptr0); - int8x8_t _w_0123_xxxx = vld1_s8(kptr0); - - _sum4 = vdotq_lane_s32(_sum4, _val0, _w_0123_xxxx, 0); - _sum5 = vdotq_lane_s32(_sum5, _val1, _w_0123_xxxx, 0); - _sum6 = vdotq_lane_s32(_sum6, _val2, _w_0123_xxxx, 0); - _sum7 = vdotq_lane_s32(_sum7, _val3, _w_0123_xxxx, 0); + _s0 = vdot_s32(_s0, vget_low_s8(_val0), _w); + _s1 = vdot_s32(_s1, vget_high_s8(_val0), _w); + _s2 = vdot_s32(_s2, vget_low_s8(_val1), _w); + _s3 = vdot_s32(_s3, vget_high_s8(_val1), _w); + _s4 = vdot_s32(_s4, vget_low_s8(_val2), _w); + _s5 = vdot_s32(_s5, vget_high_s8(_val2), _w); + _s6 = vdot_s32(_s6, vget_low_s8(_val3), _w); + _s7 = vdot_s32(_s7, vget_high_s8(_val3), _w); tmpptr += 64; - kptr0 += 4; + kptr0 += 8; } - _sum0 = vaddq_s32(_sum0, _sum4); - _sum1 = vaddq_s32(_sum1, _sum5); - _sum2 = vaddq_s32(_sum2, _sum6); - _sum3 = vaddq_s32(_sum3, _sum7); - } - - int j = 0; - for (; j < nn1; j++) - { - int8x16_t _val = vld1q_s8(tmpptr); - int8x8_t _w = vld1_dup_s8(kptr0); - - int16x8_t _s0 = vmull_s8(vget_low_s8(_val), _w); - int16x8_t _s1 = vmull_s8(vget_high_s8(_val), _w); - - _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); - _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); - _sum2 = vaddw_s16(_sum2, vget_low_s16(_s1)); - _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); - - tmpptr += 16; - kptr0 += 1; - } - - vst1q_s32(outptr0, _sum0); - vst1q_s32(outptr0 + 4, _sum1); - vst1q_s32(outptr0 + 8, _sum2); - vst1q_s32(outptr0 + 12, _sum3); - outptr0 += 16; - } - for (; i + 7 < size; i += 8) - { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p / 4 + p % 4); - - int nn = (inch / 8) * maxk; - int nn4 = ((inch % 8) / 4) * maxk; - int nn1 = (inch % 4) * maxk; - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - if (nn > 0) - { + 
_sum0 = vpaddq_s32(vcombine_s32(_s0, _s1), vcombine_s32(_s2, _s3)); + _sum1 = vpaddq_s32(vcombine_s32(_s4, _s5), vcombine_s32(_s6, _s7)); +#else // __ARM_FEATURE_MATMUL_INT8 int32x4_t _sum2 = vdupq_n_s32(0); int32x4_t _sum3 = vdupq_n_s32(0); @@ -2884,6 +3515,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons _sum0 = vaddq_s32(_sum0, _sum2); _sum1 = vaddq_s32(_sum1, _sum3); +#endif // __ARM_FEATURE_MATMUL_INT8 } if (nn4 > 0) @@ -2932,11 +3564,12 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p / 4 + p % 4); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -2946,6 +3579,30 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons if (nn > 0) { #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _s0 = vdup_n_s32(0); + int32x2_t _s1 = vdup_n_s32(0); + int32x2_t _s2 = vdup_n_s32(0); + int32x2_t _s3 = vdup_n_s32(0); + + int j = 0; + for (; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x8_t _w = vld1_s8(kptr0); + + _s0 = vdot_s32(_s0, vget_low_s8(_val0), _w); + _s1 = vdot_s32(_s1, vget_high_s8(_val0), _w); + _s2 = vdot_s32(_s2, vget_low_s8(_val1), _w); + _s3 = vdot_s32(_s3, vget_high_s8(_val1), _w); + + tmpptr += 32; + kptr0 += 8; + } + + _sum0 = vpaddq_s32(vcombine_s32(_s0, _s1), vcombine_s32(_s2, _s3)); +#else // __ARM_FEATURE_MATMUL_INT8 int32x4_t _sum1 = vdupq_n_s32(0); int j = 0; @@ -2963,6 +3620,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } _sum0 = 
vaddq_s32(_sum0, _sum1); +#endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -3134,23 +3792,25 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons vst1q_s32(outptr0, _sum0); outptr0 += 4; } -#endif // __ARM_NEON && __aarch64__ +#endif // __aarch64__ for (; i + 1 < size; i += 2) { -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif #if __ARM_NEON const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #else const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #endif +#endif #if __ARM_NEON int nn = (inch / 8) * maxk; @@ -3161,6 +3821,25 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons if (nn > 0) { #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _s0 = vdup_n_s32(0); + int32x2_t _s1 = vdup_n_s32(0); + + int j = 0; + for (; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x8_t _w = vld1_s8(kptr0); + + _s0 = vdot_s32(_s0, vget_low_s8(_val), _w); + _s1 = vdot_s32(_s1, vget_high_s8(_val), _w); + + tmpptr += 16; + kptr0 += 8; + } + + _sum = vpadd_s32(_s0, _s1); +#else // __ARM_FEATURE_MATMUL_INT8 int32x2_t _sum0 = vdup_n_s32(0); int32x2_t _sum1 = vdup_n_s32(0); @@ -3178,6 +3857,7 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } _sum = vadd_s32(_sum0, _sum1); +#endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); 
@@ -3298,19 +3978,21 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons } for (; i < size; i++) { -#if __ARM_NEON && __aarch64__ +#if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif #if __ARM_NEON const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #else const signed char* kptr0 = kernel.channel(p / 2 + p % 2); +#endif #endif int sum = 0; @@ -3440,12 +4122,22 @@ static void im2col_sgemm_int8_neon(const Mat& bottom_im2col, Mat& top_blob, cons static void convolution_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_sgemm_transform_kernel_int8_neon_i8mm(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { convolution_im2col_sgemm_transform_kernel_int8_neon_asimddp(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); return; } +#endif #endif const int maxk = kernel_w * kernel_h; @@ -3453,9 +4145,30 @@ static void convolution_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kern // interleave // src = maxk-inch-outch // dst = 8a-4b-maxk-inch/8a-outch/4b - // 
dst = 4a-4b-2-maxk-inch/8a-outch/4b (arm82) + // dst = 4a-4b-2aa-2bb-maxk-inch/8a-outch/8b (arm82) + // dst = 8a-8b-maxk-inch/8a-outch/8b (arm84) Mat kernel = _kernel.reshape(maxk, inch, outch); #if __ARM_NEON +#if __ARM_FEATURE_DOTPROD + if (outch >= 8) + { + if (inch >= 8) + kernel_tm.create(64 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u); + else if (inch >= 4) + kernel_tm.create(32 * maxk, inch / 4 + inch % 4, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u); + else + kernel_tm.create(8 * maxk, inch, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u); + } + else if (outch >= 4) + { + if (inch >= 8) + kernel_tm.create(32 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 4 + outch % 4, (size_t)1u); + else if (inch >= 4) + kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4 + outch % 4, (size_t)1u); + else + kernel_tm.create(4 * maxk, inch, outch / 4 + outch % 4, (size_t)1u); + } +#else if (outch >= 4) { if (inch >= 8) @@ -3465,6 +4178,7 @@ static void convolution_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kern else kernel_tm.create(4 * maxk, inch, outch / 4 + outch % 4, (size_t)1u); } +#endif // __ARM_FEATURE_DOTPROD #else // __ARM_NEON if (outch >= 2) { @@ -3489,16 +4203,119 @@ static void convolution_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kern int q = 0; #if __ARM_NEON +#if __ARM_FEATURE_DOTPROD + for (; q + 7 < outch; q += 8) + { + signed char* g00 = kernel_tm.channel(q / 8); + + int p = 0; + for (; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 0; i 
< 4; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#endif // __ARM_FEATURE_MATMUL_INT8 + } + } + for (; p + 3 < inch; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + } + } + for (; p < inch; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + const signed char* k00 = kernel.channel(q + i).row(p); + g00[0] = k00[k]; + g00++; + } + } + } + } +#endif // __ARM_FEATURE_DOTPROD for (; q + 3 < outch; q += 4) { +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else signed char* g00 = kernel_tm.channel(q / 4); +#endif int p = 0; for (; p + 7 < inch; p += 8) { for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#elif __ARM_FEATURE_DOTPROD for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) @@ -3581,7 +4398,11 @@ static void convolution_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kern for (; q < outch; q++) { #if __ARM_NEON +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + q % 4); +#else signed char* g00 = kernel_tm.channel(q / 4 + q % 4); +#endif #else signed char* g00 = kernel_tm.channel(q / 2 + q % 2); #endif diff --git a/src/layer/arm/convolution_sgemm_pack1to4_int8.h 
b/src/layer/arm/convolution_sgemm_pack1to4_int8.h index d5217a50c..b4babc9d2 100644 --- a/src/layer/arm/convolution_sgemm_pack1to4_int8.h +++ b/src/layer/arm/convolution_sgemm_pack1to4_int8.h @@ -12,19 +12,36 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void im2col_sgemm_pack1to4_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); +void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon_i8mm(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); +#endif + #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD void im2col_sgemm_pack1to4_int8_neon_asimddp(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon_asimddp(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); #endif +#endif static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + im2col_sgemm_pack1to4_int8_neon_i8mm(bottom_im2col, top_blob, kernel, opt); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { im2col_sgemm_pack1to4_int8_neon_asimddp(bottom_im2col, top_blob, kernel, opt); return; } +#endif #endif // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -41,9 +58,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& 
bottom_im2col, Mat& top_b #if __ARM_FEATURE_DOTPROD if (inch >= 8) { - if (size >= 16) - tmp.create(16 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); @@ -54,9 +69,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b } else if (inch >= 4) { - if (size >= 16) - tmp.create(16 * maxk, inch / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); @@ -67,9 +80,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b } else { - if (size >= 16) - tmp.create(16 * maxk, inch, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); @@ -133,15 +144,15 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - int nn_size = size >> 4; + int nn_size = size >> 3; int remain_size_start = 0; #pragma omp parallel for 
num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { - int i = remain_size_start + ii * 16; + int i = remain_size_start + ii * 8; - signed char* tmpptr = tmp.channel(i / 16); + signed char* tmpptr = tmp.channel(i / 8); int q = 0; for (; q + 7 < inch; q += 8) @@ -157,17 +168,26 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( - "ld1 {v0.16b}, [%0] \n" - "ld1 {v1.16b}, [%1] \n" - "ld1 {v2.16b}, [%2] \n" - "ld1 {v3.16b}, [%3] \n" - "ld1 {v4.16b}, [%4] \n" - "ld1 {v5.16b}, [%5] \n" - "ld1 {v6.16b}, [%6] \n" - "ld1 {v7.16b}, [%7] \n" - "st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%8], #64 \n" - "st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [%8], #64 \n" + "ld1 {v0.8b}, [%0] \n" + "ld1 {v1.8b}, [%1] \n" + "ld1 {v2.8b}, [%2] \n" + "ld1 {v3.8b}, [%3] \n" + "ld1 {v4.8b}, [%4] \n" + "ld1 {v5.8b}, [%5] \n" + "ld1 {v6.8b}, [%6] \n" + "ld1 {v7.8b}, [%7] \n" + "zip1 v8.8b, v0.8b, v4.8b \n" + "zip1 v9.8b, v1.8b, v5.8b \n" + "zip1 v10.8b, v2.8b, v6.8b \n" + "zip1 v11.8b, v3.8b, v7.8b \n" + "zip2 v0.8b, v0.8b, v4.8b \n" + "zip2 v1.8b, v1.8b, v5.8b \n" + "zip2 v2.8b, v2.8b, v6.8b \n" + "zip2 v3.8b, v3.8b, v7.8b \n" + "st4 {v8.8b, v9.8b, v10.8b, v11.8b}, [%8], #32 \n" + "st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [%8], #32 \n" : "=r"(img0), // %0 "=r"(img1), "=r"(img2), @@ -186,93 +206,8 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b "6"(img6), "7"(img7), "8"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); - img0 += size; - img1 += size; - img2 += size; - img3 += size; - img4 += size; - img5 += size; - img6 += size; - img7 += size; - } - } - for (; q + 3 < inch; q += 4) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; - const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; - const 
signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - - for (int k = 0; k < maxk; k++) - { - asm volatile( - "ld1 {v0.16b}, [%0] \n" - "ld1 {v1.16b}, [%1] \n" - "ld1 {v2.16b}, [%2] \n" - "ld1 {v3.16b}, [%3] \n" - "st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n" - : "=r"(img0), // %0 - "=r"(img1), - "=r"(img2), - "=r"(img3), - "=r"(tmpptr) // %4 - : "0"(img0), - "1"(img1), - "2"(img2), - "3"(img3), - "4"(tmpptr) - : "memory", "v0", "v1", "v2", "v3"); - img0 += size; - img1 += size; - img2 += size; - img3 += size; - } - } - for (; q < inch; q++) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - - for (int k = 0; k < maxk; k++) - { - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.16b}, [%0] \n" - "st1 {v0.16b}, [%1], #16 \n" - : "=r"(img0), // %0 - "=r"(tmpptr) // %1 - : "0"(img0), - "1"(tmpptr) - : "memory", "v0"); - img0 += size; - } - } - } - - remain_size_start += nn_size << 4; - nn_size = (size - remain_size_start) >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) - { - int i = remain_size_start + ii * 8; - - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - - int q = 0; - for (; q + 7 < inch; q += 8) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; - const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; - const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i; - const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i; - const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i; - const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i; - - for (int k = 0; k < maxk; k++) - { + : "memory", "v0", "v1", "v2", "v3", "v4", 
"v5", "v6", "v7", "v8", "v9", "v10", "v11"); +#else // __ARM_FEATURE_MATMUL_INT8 asm volatile( "ld1 {v0.8b}, [%0] \n" "ld1 {v1.8b}, [%1] \n" @@ -303,6 +238,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b "7"(img7), "8"(tmpptr) : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#endif // __ARM_FEATURE_MATMUL_INT8 img0 += size; img1 += size; img2 += size; @@ -378,7 +314,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int i = remain_size_start + ii * 4; #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); #else signed char* tmpptr = tmp.channel(i / 4); #endif @@ -397,7 +333,47 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img4[0]; + tmpptr[5] = img5[0]; + tmpptr[6] = img6[0]; + tmpptr[7] = img7[0]; + tmpptr += 8; + + tmpptr[0] = img0[1]; + tmpptr[1] = img1[1]; + tmpptr[2] = img2[1]; + tmpptr[3] = img3[1]; + tmpptr[4] = img4[1]; + tmpptr[5] = img5[1]; + tmpptr[6] = img6[1]; + tmpptr[7] = img7[1]; + tmpptr += 8; + + tmpptr[0] = img0[2]; + tmpptr[1] = img1[2]; + tmpptr[2] = img2[2]; + tmpptr[3] = img3[2]; + tmpptr[4] = img4[2]; + tmpptr[5] = img5[2]; + tmpptr[6] = img6[2]; + tmpptr[7] = img7[2]; + tmpptr += 8; + + tmpptr[0] = img0[3]; + tmpptr[1] = img1[3]; + tmpptr[2] = img2[3]; + tmpptr[3] = img3[3]; + tmpptr[4] = img4[3]; + tmpptr[5] = img5[3]; + tmpptr[6] = img6[3]; + tmpptr[7] = img7[3]; + tmpptr += 8; +#elif __ARM_FEATURE_DOTPROD tmpptr[0] = img0[0]; tmpptr[1] = img1[0]; tmpptr[2] = img2[0]; @@ -477,7 +453,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b tmpptr[6] = img6[3]; tmpptr[7] = img7[3]; tmpptr += 8; 
-#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size; img1 += size; @@ -556,7 +532,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); #endif @@ -578,7 +554,27 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img4[0]; + tmpptr[5] = img5[0]; + tmpptr[6] = img6[0]; + tmpptr[7] = img7[0]; + tmpptr += 8; + + tmpptr[0] = img0[1]; + tmpptr[1] = img1[1]; + tmpptr[2] = img2[1]; + tmpptr[3] = img3[1]; + tmpptr[4] = img4[1]; + tmpptr[5] = img5[1]; + tmpptr[6] = img6[1]; + tmpptr[7] = img7[1]; + tmpptr += 8; +#elif __ARM_FEATURE_DOTPROD tmpptr[0] = img0[0]; tmpptr[1] = img1[0]; tmpptr[2] = img2[0]; @@ -618,7 +614,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b tmpptr[6] = img6[1]; tmpptr[7] = img7[1]; tmpptr += 8; -#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size; img1 += size; @@ -678,7 +674,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); #endif @@ -757,18 +753,23 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b } } +#if __ARM_FEATURE_DOTPROD + int nn_outch = outch / 2; + int remain_outch_start = nn_outch * 2; + #pragma omp parallel for 
num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp = 0; pp < nn_outch; pp++) { + int p = pp * 2; + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); int i = 0; -#if __aarch64__ -#if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 2); int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -792,231 +793,742 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b "eor v30.16b, v30.16b, v30.16b \n" "eor v31.16b, v31.16b, v31.16b \n" - "cmp %w1, #0 \n" + "cmp %w2, #0 \n" "beq 1f \n" - "ld1 {v8.16b}, [%5], #16 \n" // _w0123_l +#if __ARM_FEATURE_MATMUL_INT8 + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + + "0: \n" + + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%5], #64 \n" // _val0 _val1 _val2 _val3 + "ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%6], #64 \n" // _w01 _w23 _w45 _w67 - "ld1 {v0.16b}, [%4], #16 \n" // _val0123_l + "smmla v4.4s, v0.16b, v8.16b \n" + "smmla v17.4s, v0.16b, v9.16b \n" + "smmla v5.4s, v1.16b, v8.16b \n" + "smmla v19.4s, v1.16b, v9.16b \n" + "smmla v6.4s, v2.16b, v8.16b \n" + "smmla v21.4s, v2.16b, v9.16b \n" + "smmla v7.4s, v3.16b, v8.16b \n" + "smmla v23.4s, v3.16b, v9.16b \n" + "subs %w2, %w2, #1 \n" + + "smmla v12.4s, v0.16b, v10.16b \n" + "smmla v25.4s, v0.16b, v11.16b \n" + "smmla v13.4s, v1.16b, v10.16b \n" + "smmla v27.4s, v1.16b, v11.16b \n" + "smmla v14.4s, v2.16b, v10.16b \n" + "smmla v29.4s, v2.16b, v11.16b \n" + "smmla v15.4s, v3.16b, v10.16b \n" + "smmla v31.4s, v3.16b, v11.16b \n" 
+ + "bne 0b \n" + + "trn1 v16.2d, v4.2d, v17.2d \n" + "trn2 v17.2d, v4.2d, v17.2d \n" + "trn1 v18.2d, v5.2d, v19.2d \n" + "trn2 v19.2d, v5.2d, v19.2d \n" + "trn1 v20.2d, v6.2d, v21.2d \n" + "trn2 v21.2d, v6.2d, v21.2d \n" + "trn1 v22.2d, v7.2d, v23.2d \n" + "trn2 v23.2d, v7.2d, v23.2d \n" + + "trn1 v24.2d, v12.2d, v25.2d \n" + "trn2 v25.2d, v12.2d, v25.2d \n" + "trn1 v26.2d, v13.2d, v27.2d \n" + "trn2 v27.2d, v13.2d, v27.2d \n" + "trn1 v28.2d, v14.2d, v29.2d \n" + "trn2 v29.2d, v14.2d, v29.2d \n" + "trn1 v30.2d, v15.2d, v31.2d \n" + "trn2 v31.2d, v15.2d, v31.2d \n" +#else // __ARM_FEATURE_MATMUL_INT8 "0: \n" - "ld1 {v1.16b}, [%4], #16 \n" // _val4567_l + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%5], #64 \n" // _val0123_l _val4567_l _val0123_h _val4567_h + "ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%6], #64 \n" // _w0123_l _w0123_h _w4567_l _w4567_h "sdot v16.4s, v8.16b, v0.4b[0] \n" "sdot v17.4s, v8.16b, v0.4b[1] \n" "sdot v18.4s, v8.16b, v0.4b[2] \n" "sdot v19.4s, v8.16b, v0.4b[3] \n" + "sdot v20.4s, v8.16b, v1.4b[0] \n" + "sdot v21.4s, v8.16b, v1.4b[1] \n" + "sdot v22.4s, v8.16b, v1.4b[2] \n" + "sdot v23.4s, v8.16b, v1.4b[3] \n" + + "sdot v16.4s, v9.16b, v2.4b[0] \n" + "sdot v17.4s, v9.16b, v2.4b[1] \n" + "sdot v18.4s, v9.16b, v2.4b[2] \n" + "sdot v19.4s, v9.16b, v2.4b[3] \n" + "sdot v20.4s, v9.16b, v3.4b[0] \n" + "sdot v21.4s, v9.16b, v3.4b[1] \n" + "sdot v22.4s, v9.16b, v3.4b[2] \n" + "sdot v23.4s, v9.16b, v3.4b[3] \n" + + "subs %w2, %w2, #1 \n" + + "sdot v24.4s, v10.16b, v0.4b[0] \n" + "sdot v25.4s, v10.16b, v0.4b[1] \n" + "sdot v26.4s, v10.16b, v0.4b[2] \n" + "sdot v27.4s, v10.16b, v0.4b[3] \n" + "sdot v28.4s, v10.16b, v1.4b[0] \n" + "sdot v29.4s, v10.16b, v1.4b[1] \n" + "sdot v30.4s, v10.16b, v1.4b[2] \n" + "sdot v31.4s, v10.16b, v1.4b[3] \n" + + "sdot v24.4s, v11.16b, v2.4b[0] \n" + "sdot v25.4s, v11.16b, v2.4b[1] \n" + "sdot v26.4s, v11.16b, v2.4b[2] \n" + "sdot v27.4s, v11.16b, v2.4b[3] \n" + "sdot v28.4s, v11.16b, v3.4b[0] \n" + "sdot v29.4s, v11.16b, 
v3.4b[1] \n" + "sdot v30.4s, v11.16b, v3.4b[2] \n" + "sdot v31.4s, v11.16b, v3.4b[3] \n" - "ld1 {v2.16b}, [%4], #16 \n" // _val891011_l + "bne 0b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "1: \n" + + "cmp %w3, #0 \n" + "beq 3f \n" + + "2: \n" + + "ld1 {v0.16b, v1.16b}, [%5], #32 \n" // _val0123 _val4567 + "ld1 {v8.16b, v9.16b}, [%6], #32 \n" // _w0 _w1 + "sdot v16.4s, v8.16b, v0.4b[0] \n" + "sdot v17.4s, v8.16b, v0.4b[1] \n" + "sdot v18.4s, v8.16b, v0.4b[2] \n" + "sdot v19.4s, v8.16b, v0.4b[3] \n" "sdot v20.4s, v8.16b, v1.4b[0] \n" "sdot v21.4s, v8.16b, v1.4b[1] \n" "sdot v22.4s, v8.16b, v1.4b[2] \n" "sdot v23.4s, v8.16b, v1.4b[3] \n" - "ld1 {v3.16b}, [%4], #16 \n" // _val12131415_l + "subs %w3, %w3, #1 \n" - "sdot v24.4s, v8.16b, v2.4b[0] \n" - "sdot v25.4s, v8.16b, v2.4b[1] \n" + "sdot v24.4s, v9.16b, v0.4b[0] \n" + "sdot v25.4s, v9.16b, v0.4b[1] \n" + "sdot v26.4s, v9.16b, v0.4b[2] \n" + "sdot v27.4s, v9.16b, v0.4b[3] \n" + "sdot v28.4s, v9.16b, v1.4b[0] \n" + "sdot v29.4s, v9.16b, v1.4b[1] \n" + "sdot v30.4s, v9.16b, v1.4b[2] \n" + "sdot v31.4s, v9.16b, v1.4b[3] \n" - "ld1 {v9.16b}, [%5], #16 \n" // _w0123_h + "bne 2b \n" - "sdot v26.4s, v8.16b, v2.4b[2] \n" - "sdot v27.4s, v8.16b, v2.4b[3] \n" + "3: \n" - "ld1 {v4.16b}, [%4], #16 \n" // _val0123_h + "lsr w4, %w4, #2 \n" // w4 = nn1 >> 2 + "cmp w4, #0 \n" + "beq 5f \n" - "sdot v28.4s, v8.16b, v3.4b[0] \n" - "sdot v29.4s, v8.16b, v3.4b[1] \n" - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" + "4: \n" - "ld1 {v5.16b}, [%4], #16 \n" // _val4567_h + "ld2 {v0.4s, v1.4s}, [%5], #32 \n" + "ld2 {v8.4s, v9.4s}, [%6], #32 \n" + + "uzp1 v2.16b, v0.16b, v1.16b \n" + "uzp2 v3.16b, v0.16b, v1.16b \n" + "uzp1 v0.16b, v2.16b, v3.16b \n" + "uzp2 v1.16b, v2.16b, v3.16b \n" + "uzp1 v2.4s, v0.4s, v1.4s \n" // _val0123 + "uzp2 v3.4s, v0.4s, v1.4s \n" // _val4567 + + "uzp1 v10.16b, v8.16b, v9.16b \n" + "uzp2 v11.16b, v8.16b, v9.16b \n" + "uzp1 v8.16b, v10.16b, v11.16b \n" + "uzp2 v9.16b, v10.16b, v11.16b 
\n" + "uzp1 v10.4s, v8.4s, v9.4s \n" // _w0123f + "uzp2 v11.4s, v8.4s, v9.4s \n" // _w4567f + + "sdot v16.4s, v10.16b, v2.4b[0] \n" + "sdot v17.4s, v10.16b, v2.4b[1] \n" + "sdot v18.4s, v10.16b, v2.4b[2] \n" + "sdot v19.4s, v10.16b, v2.4b[3] \n" + "sdot v20.4s, v10.16b, v3.4b[0] \n" + "sdot v21.4s, v10.16b, v3.4b[1] \n" + "sdot v22.4s, v10.16b, v3.4b[2] \n" + "sdot v23.4s, v10.16b, v3.4b[3] \n" - "sdot v16.4s, v9.16b, v4.4b[0] \n" - "sdot v17.4s, v9.16b, v4.4b[1] \n" - "sdot v18.4s, v9.16b, v4.4b[2] \n" - "sdot v19.4s, v9.16b, v4.4b[3] \n" + "subs w4, w4, #1 \n" - "ld1 {v6.16b}, [%4], #16 \n" // _val891011_h + "sdot v24.4s, v11.16b, v2.4b[0] \n" + "sdot v25.4s, v11.16b, v2.4b[1] \n" + "sdot v26.4s, v11.16b, v2.4b[2] \n" + "sdot v27.4s, v11.16b, v2.4b[3] \n" + "sdot v28.4s, v11.16b, v3.4b[0] \n" + "sdot v29.4s, v11.16b, v3.4b[1] \n" + "sdot v30.4s, v11.16b, v3.4b[2] \n" + "sdot v31.4s, v11.16b, v3.4b[3] \n" - "sdot v20.4s, v9.16b, v5.4b[0] \n" - "sdot v21.4s, v9.16b, v5.4b[1] \n" - "sdot v22.4s, v9.16b, v5.4b[2] \n" - "sdot v23.4s, v9.16b, v5.4b[3] \n" + "bne 4b \n" - "ld1 {v7.16b}, [%4], #16 \n" // _val12131415_h + "5: \n" - "sdot v24.4s, v9.16b, v6.4b[0] \n" - "sdot v25.4s, v9.16b, v6.4b[1] \n" + "and w4, %w4, #3 \n" // w4 = remain = nn1 & 3 + "cmp w4, #0 \n" // w4 > 0 + "beq 7f \n" - "ld1 {v8.16b}, [%5], #16 \n" // _w0123_l + "6: \n" - "sdot v26.4s, v9.16b, v6.4b[2] \n" - "sdot v27.4s, v9.16b, v6.4b[3] \n" + "ld1 {v0.8b}, [%5], #8 \n" + "ld1 {v1.8b}, [%6], #8 \n" - "ld1 {v0.16b}, [%4], #16 \n" // _val0123_l + "sshll v0.8h, v0.8b, #0 \n" - "sdot v28.4s, v9.16b, v7.4b[0] \n" - "sdot v29.4s, v9.16b, v7.4b[1] \n" + "sshll v1.8h, v1.8b, #0 \n" - "subs %w1, %w1, #1 \n" + "smlal v16.4s, v1.4h, v0.h[0] \n" + "smlal v17.4s, v1.4h, v0.h[1] \n" + "smlal v18.4s, v1.4h, v0.h[2] \n" + "smlal v19.4s, v1.4h, v0.h[3] \n" + "smlal v20.4s, v1.4h, v0.h[4] \n" + "smlal v21.4s, v1.4h, v0.h[5] \n" + "smlal v22.4s, v1.4h, v0.h[6] \n" + "smlal v23.4s, v1.4h, v0.h[7] \n" - "sdot v30.4s, 
v9.16b, v7.4b[2] \n" - "sdot v31.4s, v9.16b, v7.4b[3] \n" + "subs w4, w4, #1 \n" - "bne 0b \n" + "smlal2 v24.4s, v1.8h, v0.h[0] \n" + "smlal2 v25.4s, v1.8h, v0.h[1] \n" + "smlal2 v26.4s, v1.8h, v0.h[2] \n" + "smlal2 v27.4s, v1.8h, v0.h[3] \n" + "smlal2 v28.4s, v1.8h, v0.h[4] \n" + "smlal2 v29.4s, v1.8h, v0.h[5] \n" + "smlal2 v30.4s, v1.8h, v0.h[6] \n" + "smlal2 v31.4s, v1.8h, v0.h[7] \n" - "sub %4, %4, #16 \n" - "sub %5, %5, #16 \n" + "bne 6b \n" - "1: \n" + "7: \n" - "cmp %w2, #0 \n" - "beq 3f \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%1], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%1], #64 \n" + : "=r"(outptr0), + "=r"(outptr1), + "=r"(nn), + "=r"(nn4), + "=r"(nn1), + "=r"(tmpptr), + "=r"(kptr0) + : "0"(outptr0), + "1"(outptr1), + "2"(nn), + "3"(nn4), + "4"(nn1), + "5"(tmpptr), + "6"(kptr0) + : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 2); - "2: \n" + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; - "ld1 {v8.16b}, [%5], #16 \n" + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); - "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n" +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t 
_w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + _sum4 = vmmlaq_s32(_sum4, _val0, _w45); + _sum5 = vmmlaq_s32(_sum5, _val0, _w67); + _sum6 = vmmlaq_s32(_sum6, _val1, _w45); + _sum7 = vmmlaq_s32(_sum7, _val1, _w67); - "sdot v16.4s, v8.16b, v0.4b[0] \n" - "sdot v17.4s, v8.16b, v0.4b[1] \n" - "sdot v18.4s, v8.16b, v0.4b[2] \n" - "sdot v19.4s, v8.16b, v0.4b[3] \n" - "sdot v20.4s, v8.16b, v1.4b[0] \n" - "sdot v21.4s, v8.16b, v1.4b[1] \n" - "sdot v22.4s, v8.16b, v1.4b[2] \n" - "sdot v23.4s, v8.16b, v1.4b[3] \n" - "sdot v24.4s, v8.16b, v2.4b[0] \n" - "sdot v25.4s, v8.16b, v2.4b[1] \n" - "sdot v26.4s, v8.16b, v2.4b[2] \n" - "sdot v27.4s, v8.16b, v2.4b[3] \n" - "sdot v28.4s, v8.16b, v3.4b[0] \n" - "sdot v29.4s, v8.16b, v3.4b[1] \n" + tmpptr += 32; + kptr0 += 64; + } - "subs %w2, %w2, #1 \n" + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum4x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum5x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum6x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + int32x4_t _sum7x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + 
_sum2 = _sum2x; + _sum3 = _sum3x; + _sum4 = _sum4x; + _sum5 = _sum5x; + _sum6 = _sum6x; + _sum7 = _sum7x; +#else // __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0123_l = vld1q_s8(tmpptr); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3); - "bne 2b \n" + _sum4 = vdotq_laneq_s32(_sum4, _w4567_l, _val0123_l, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_l, _val0123_l, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_l, _val0123_l, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_l, _val0123_l, 3); + _sum4 = vdotq_laneq_s32(_sum4, _w4567_h, _val0123_h, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_h, _val0123_h, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_h, _val0123_h, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_h, _val0123_h, 3); - "3: \n" + tmpptr += 32; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 - "lsr w4, %w3, #2 \n" // w4 = nn1 >> 2 - "cmp w4, #0 \n" - "beq 5f \n" + for (int j = 0; j < nn4; j++) + { + int8x16_t _val0123 = vld1q_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); - "4: \n" + _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0, _val0123, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0, _val0123, 3); 
+ + _sum4 = vdotq_laneq_s32(_sum4, _w1, _val0123, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w1, _val0123, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w1, _val0123, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w1, _val0123, 3); + + tmpptr += 16; + kptr0 += 32; + } + + int j = 0; + for (; j + 3 < nn1; j += 4) + { + // 0123 0123 0123 0123 -> 0000111122223333 + int8x16_t _val = vld1q_s8(tmpptr); + + int8x8x2_t _val01 = vuzp_s8(vget_low_s8(_val), vget_high_s8(_val)); + int8x8x2_t _val0123 = vuzp_s8(_val01.val[0], _val01.val[1]); + int8x16_t _val0123f = vcombine_s8(_val0123.val[0], _val0123.val[1]); + + // 0123 4567 0123 4567 0123 4567 0123 4567 -> 0000111122223333 + int32x4x2_t _w = vld2q_s32((const int*)kptr0); + + int8x16_t _w0 = vreinterpretq_s8_s32(_w.val[0]); + int8x16_t _w1 = vreinterpretq_s8_s32(_w.val[1]); + + int8x8x2_t _w01 = vuzp_s8(vget_low_s8(_w0), vget_high_s8(_w0)); + int8x8x2_t _w0123 = vuzp_s8(_w01.val[0], _w01.val[1]); + int8x16_t _w0123f = vcombine_s8(_w0123.val[0], _w0123.val[1]); + + int8x8x2_t _w45 = vuzp_s8(vget_low_s8(_w1), vget_high_s8(_w1)); + int8x8x2_t _w4567 = vuzp_s8(_w45.val[0], _w45.val[1]); + int8x16_t _w4567f = vcombine_s8(_w4567.val[0], _w4567.val[1]); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123f, _val0123f, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123f, _val0123f, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123f, _val0123f, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123f, _val0123f, 3); + + _sum4 = vdotq_laneq_s32(_sum4, _w4567f, _val0123f, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567f, _val0123f, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567f, _val0123f, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567f, _val0123f, 3); + + tmpptr += 16; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val0 = vdup_n_s16(tmpptr[0]); + int16x4_t _val1 = vdup_n_s16(tmpptr[1]); + int16x4_t _val2 = vdup_n_s16(tmpptr[2]); + int16x4_t _val3 = vdup_n_s16(tmpptr[3]); + + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); + + _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w01)); + _sum1 = 
vmlal_s16(_sum1, _val1, vget_low_s16(_w01)); + _sum2 = vmlal_s16(_sum2, _val2, vget_low_s16(_w01)); + _sum3 = vmlal_s16(_sum3, _val3, vget_low_s16(_w01)); - "ld1 {v8.8b, v9.8b}, [%5], #16 \n" + _sum4 = vmlal_s16(_sum4, _val0, vget_high_s16(_w01)); + _sum5 = vmlal_s16(_sum5, _val1, vget_high_s16(_w01)); + _sum6 = vmlal_s16(_sum6, _val2, vget_high_s16(_w01)); + _sum7 = vmlal_s16(_sum7, _val3, vget_high_s16(_w01)); - "ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n" + tmpptr += 4; + kptr0 += 8; + } + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr0 + 4, _sum1); + vst1q_s32(outptr0 + 8, _sum2); + vst1q_s32(outptr0 + 12, _sum3); + vst1q_s32(outptr1, _sum4); + vst1q_s32(outptr1 + 4, _sum5); + vst1q_s32(outptr1 + 8, _sum6); + vst1q_s32(outptr1 + 12, _sum7); + outptr0 += 16; + outptr1 += 16; + } + for (; i + 1 < size; i += 2) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 2); + + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; + + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + _sum2 = vmmlaq_s32(_sum2, _val, _w45); + _sum3 = vmmlaq_s32(_sum3, _val, _w67); + + tmpptr += 16; + kptr0 += 64; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), 
vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; +#else // __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val01_l_h = vld1q_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val01_l_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val01_l_h, 1); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val01_l_h, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val01_l_h, 3); + + _sum2 = vdotq_laneq_s32(_sum2, _w4567_l, _val01_l_h, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_l, _val01_l_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w4567_h, _val01_l_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_h, _val01_l_h, 3); + + tmpptr += 16; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + if (nn4 > 0) + { + int j = 0; + for (; j + 1 < nn4; j += 2) + { + int8x16_t _val0123 = vld1q_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int8x16_t _w2 = vld1q_s8(kptr0 + 32); + int8x16_t _w3 = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w1, _val0123, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w1, _val0123, 1); + + _sum0 = vdotq_laneq_s32(_sum0, _w2, _val0123, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w2, _val0123, 3); + _sum2 = vdotq_laneq_s32(_sum2, _w3, _val0123, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _val0123, 3); + + tmpptr += 16; + kptr0 += 64; + } + for (; j < nn4; j++) + { + int8x8_t _val01 = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + + _sum0 = vdotq_lane_s32(_sum0, _w0, _val01, 0); + _sum1 = 
vdotq_lane_s32(_sum1, _w0, _val01, 1); + _sum2 = vdotq_lane_s32(_sum2, _w1, _val01, 0); + _sum3 = vdotq_lane_s32(_sum3, _w1, _val01, 1); + + tmpptr += 8; + kptr0 += 32; + } + } + + int j = 0; + for (; j + 3 < nn1; j += 4) + { + int16x8_t _val01234567 = vmovl_s8(vld1_s8(tmpptr)); + + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int16x8_t _w0l = vmovl_s8(vget_low_s8(_w0)); + int16x8_t _w0h = vmovl_s8(vget_high_s8(_w0)); + int16x8_t _w1l = vmovl_s8(vget_low_s8(_w1)); + int16x8_t _w1h = vmovl_s8(vget_high_s8(_w1)); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w0l), _val01234567, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w0l), _val01234567, 1); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w0l), _val01234567, 0); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w0l), _val01234567, 1); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w0h), _val01234567, 2); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w0h), _val01234567, 3); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w0h), _val01234567, 2); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w0h), _val01234567, 3); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w1l), _val01234567, 4); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w1l), _val01234567, 5); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w1l), _val01234567, 4); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w1l), _val01234567, 5); + + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_w1h), _val01234567, 6); + _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_w1h), _val01234567, 7); + _sum2 = vmlal_laneq_s16(_sum2, vget_high_s16(_w1h), _val01234567, 6); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_w1h), _val01234567, 7); + + tmpptr += 8; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val0 = vdup_n_s16(tmpptr[0]); + int16x4_t _val1 = vdup_n_s16(tmpptr[1]); + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); + + _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w01)); + _sum1 = vmlal_s16(_sum1, _val1, 
vget_low_s16(_w01)); + _sum2 = vmlal_s16(_sum2, _val0, vget_high_s16(_w01)); + _sum3 = vmlal_s16(_sum3, _val1, vget_high_s16(_w01)); + + tmpptr += 2; + kptr0 += 8; + } + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr0 + 4, _sum1); + vst1q_s32(outptr1, _sum2); + vst1q_s32(outptr1 + 4, _sum3); + outptr0 += 8; + outptr1 += 8; + } + for (; i < size; i++) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 2); + + int nn = (inch / 8) * maxk; + int nn4 = ((inch % 8) / 4) * maxk; + int nn1 = (inch % 4) * maxk; + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + _sum45 = vdotq_s32(_sum45, _val, _w45); + _sum67 = vdotq_s32(_sum67, _val, _w67); + + tmpptr += 8; + kptr0 += 64; + } + + int32x4_t _sum0 = vpaddq_s32(_sum01, _sum23); + int32x4_t _sum1 = vpaddq_s32(_sum45, _sum67); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); - "uzp1 v10.8b, v8.8b, v9.8b \n" - "uzp2 v11.8b, v8.8b, v9.8b \n" + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); - "uzp1 v4.16b, v0.16b, v1.16b \n" - "uzp2 v5.16b, v0.16b, v1.16b \n" - "uzp1 v6.16b, v2.16b, v3.16b \n" - "uzp2 v7.16b, v2.16b, v3.16b \n" + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum0 = 
vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); + _sum1 = vdotq_lane_s32(_sum1, _w4567_l, _val0_l_h, 0); + _sum1 = vdotq_lane_s32(_sum1, _w4567_h, _val0_l_h, 1); - "uzp1 v8.8b, v10.8b, v11.8b \n" - "uzp2 v9.8b, v10.8b, v11.8b \n" + tmpptr += 8; + kptr0 += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 - "uzp1 v0.16b, v4.16b, v5.16b \n" // 0 1 4 5 - "uzp2 v1.16b, v4.16b, v5.16b \n" // 8 9 c d + if (nn4 > 0) + { + int j = 0; + for (; j + 1 < nn4; j += 2) + { + int8x8_t _val01 = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int8x16_t _w2 = vld1q_s8(kptr0 + 32); + int8x16_t _w3 = vld1q_s8(kptr0 + 48); - "mov v8.d[1], v9.d[0] \n" // _w + _sum0 = vdotq_lane_s32(_sum0, _w0, _val01, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _val01, 0); + _sum0 = vdotq_lane_s32(_sum0, _w2, _val01, 1); + _sum1 = vdotq_lane_s32(_sum1, _w3, _val01, 1); - "uzp1 v2.16b, v6.16b, v7.16b \n" // 2 3 6 7 - "uzp2 v3.16b, v6.16b, v7.16b \n" // a b e f + tmpptr += 8; + kptr0 += 64; + } + for (; j < nn4; j++) + { + int8x8_t _val_xxx = vld1_s8(tmpptr); + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); - "sdot v16.4s, v8.16b, v0.4b[0] \n" - "sdot v17.4s, v8.16b, v0.4b[1] \n" - "sdot v18.4s, v8.16b, v2.4b[0] \n" - "sdot v19.4s, v8.16b, v2.4b[1] \n" - "sdot v20.4s, v8.16b, v0.4b[2] \n" - "sdot v21.4s, v8.16b, v0.4b[3] \n" - "sdot v22.4s, v8.16b, v2.4b[2] \n" - "sdot v23.4s, v8.16b, v2.4b[3] \n" - "sdot v24.4s, v8.16b, v1.4b[0] \n" - "sdot v25.4s, v8.16b, v1.4b[1] \n" - "sdot v26.4s, v8.16b, v3.4b[0] \n" - "sdot v27.4s, v8.16b, v3.4b[1] \n" - "sdot v28.4s, v8.16b, v1.4b[2] \n" - "sdot v29.4s, v8.16b, v1.4b[3] \n" - "sdot v30.4s, v8.16b, v3.4b[2] \n" - "sdot v31.4s, v8.16b, v3.4b[3] \n" + _sum0 = vdotq_lane_s32(_sum0, _w0, _val_xxx, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _val_xxx, 0); - "subs w4, w4, #1 \n" - "bne 4b \n" + tmpptr += 4; + kptr0 += 32; + } + } - "5: \n" + int j = 0; + for (; j + 3 < nn1; j += 4) + { + int16x4_t _val0123 = 
vget_low_s16(vmovl_s8(vld1_s8(tmpptr))); - "and w4, %w3, #3 \n" // w4 = remain = nn1 & 3 - "cmp w4, #0 \n" // w4 > 0 - "beq 7f \n" + int8x16_t _w0 = vld1q_s8(kptr0); + int8x16_t _w1 = vld1q_s8(kptr0 + 16); + int16x8_t _w0l = vmovl_s8(vget_low_s8(_w0)); + int16x8_t _w0h = vmovl_s8(vget_high_s8(_w0)); + int16x8_t _w1l = vmovl_s8(vget_low_s8(_w1)); + int16x8_t _w1h = vmovl_s8(vget_high_s8(_w1)); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0l), _val0123, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0l), _val0123, 0); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0h), _val0123, 1); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0h), _val0123, 1); + + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1l), _val0123, 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1l), _val0123, 2); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1h), _val0123, 3); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1h), _val0123, 3); - "6: \n" + tmpptr += 4; + kptr0 += 32; + } + for (; j < nn1; j++) + { + int16x4_t _val = vdup_n_s16(tmpptr[0]); + int16x8_t _w01 = vmovl_s8(vld1_s8(kptr0)); - "ld1 {v1.8b}, [%5] \n" - "ld1 {v0.16b}, [%4] \n" + _sum0 = vmlal_s16(_sum0, _val, vget_low_s16(_w01)); + _sum1 = vmlal_s16(_sum1, _val, vget_high_s16(_w01)); - "sshll v1.8h, v1.8b, #0 \n" - "sshll v2.8h, v0.8b, #0 \n" - "sshll2 v3.8h, v0.16b, #0 \n" - - "smlal v16.4s, v1.4h, v2.h[0] \n" - "smlal v17.4s, v1.4h, v2.h[1] \n" - "smlal v18.4s, v1.4h, v2.h[2] \n" - "smlal v19.4s, v1.4h, v2.h[3] \n" - "smlal v20.4s, v1.4h, v2.h[4] \n" - "smlal v21.4s, v1.4h, v2.h[5] \n" - "smlal v22.4s, v1.4h, v2.h[6] \n" - "smlal v23.4s, v1.4h, v2.h[7] \n" - "smlal v24.4s, v1.4h, v3.h[0] \n" - "smlal v25.4s, v1.4h, v3.h[1] \n" - "smlal v26.4s, v1.4h, v3.h[2] \n" - "smlal v27.4s, v1.4h, v3.h[3] \n" - "smlal v28.4s, v1.4h, v3.h[4] \n" - "smlal v29.4s, v1.4h, v3.h[5] \n" - "smlal v30.4s, v1.4h, v3.h[6] \n" - "smlal v31.4s, v1.4h, v3.h[7] \n" - - "add %4, %4, #16 \n" - "add %5, %5, #4 \n" + tmpptr += 1; + kptr0 += 8; + } 
- "subs w4, w4, #1 \n" - "bne 6b \n" + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr1, _sum1); + outptr0 += 4; + outptr1 += 4; + } + } +#else // __ARM_FEATURE_DOTPROD + int remain_outch_start = 0; +#endif // __ARM_FEATURE_DOTPROD - "7: \n" + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = remain_outch_start; p < outch; p++) + { + int* outptr0 = top_blob.channel(p); - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" - : "=r"(outptr0), - "=r"(nn), - "=r"(nn4), - "=r"(nn1), - "=r"(tmpptr), - "=r"(kptr0) - : "0"(outptr0), - "1"(nn), - "2"(nn4), - "3"(nn1), - "4"(tmpptr), - "5"(kptr0) - : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - } + int i = 0; +#if __aarch64__ +#if __ARM_FEATURE_DOTPROD for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1031,6 +1543,48 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int32x4_t _sum6 = vdupq_n_s32(0); int32x4_t _sum7 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _val2 = vld1q_s8(tmpptr + 32); + int8x16_t _val3 = vld1q_s8(tmpptr + 48); + + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, 
_val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + _sum4 = vmmlaq_s32(_sum4, _val2, _w01); + _sum5 = vmmlaq_s32(_sum5, _val2, _w23); + _sum6 = vmmlaq_s32(_sum6, _val3, _w01); + _sum7 = vmmlaq_s32(_sum7, _val3, _w23); + + tmpptr += 64; + kptr0 += 32; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum4x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum5x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum4), vreinterpretq_s64_s32(_sum5))); + int32x4_t _sum6x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + int32x4_t _sum7x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum6), vreinterpretq_s64_s32(_sum7))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; + _sum4 = _sum4x; + _sum5 = _sum5x; + _sum6 = _sum6x; + _sum7 = _sum7x; +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val0123_l = vld1q_s8(tmpptr); @@ -1064,6 +1618,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b tmpptr += 64; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn4; j++) { @@ -1153,15 +1708,16 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b vst1q_s32(outptr0 + 28, _sum7); outptr0 += 32; } -#endif +#endif // __ARM_FEATURE_DOTPROD for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* 
tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1172,6 +1728,33 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int32x4_t _sum2 = vdupq_n_s32(0); int32x4_t _sum3 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + tmpptr += 32; + kptr0 += 32; + } + + int32x4_t _sum0x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum1x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum0), vreinterpretq_s64_s32(_sum1))); + int32x4_t _sum2x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + int32x4_t _sum3x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum2), vreinterpretq_s64_s32(_sum3))); + + _sum0 = _sum0x; + _sum1 = _sum1x; + _sum2 = _sum2x; + _sum3 = _sum3x; +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val0123_l = vld1q_s8(tmpptr); @@ -1193,6 +1776,7 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b tmpptr += 32; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn4; j++) { @@ -1654,14 +2238,16 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* 
tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -1670,22 +2256,42 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int32x4_t _sum00 = vdupq_n_s32(0); int32x4_t _sum10 = vdupq_n_s32(0); #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum00 = vmmlaq_s32(_sum00, _val, _w01); + _sum10 = vmmlaq_s32(_sum10, _val, _w23); + + tmpptr += 16; + kptr0 += 32; + } + + int32x4_t _sum00x = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum00), vreinterpretq_s64_s32(_sum10))); + int32x4_t _sum10x = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum00), vreinterpretq_s64_s32(_sum10))); + + _sum00 = _sum00x; + _sum10 = _sum10x; +#else // __ARM_FEATURE_MATMUL_INT8 for (int j = 0; j < nn; j++) { int8x16_t _val01_l_h = vld1q_s8(tmpptr); int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); _sum00 = vdotq_laneq_s32(_sum00, _w0123_l, _val01_l_h, 0); _sum10 = vdotq_laneq_s32(_sum10, _w0123_l, _val01_l_h, 1); - int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); - _sum00 = vdotq_laneq_s32(_sum00, _w0123_h, _val01_l_h, 2); _sum10 = vdotq_laneq_s32(_sum10, _w0123_h, _val01_l_h, 3); tmpptr += 16; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 if (nn4 > 0) { @@ -2197,14 +2803,16 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) 
/ 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = (inch / 8) * maxk; int nn4 = ((inch % 8) / 4) * maxk; @@ -2212,21 +2820,39 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int32x4_t _sum0 = vdupq_n_s32(0); #if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum23 = vdupq_n_s32(0); + for (int j = 0; j < nn; j++) { - int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); - int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _val = vcombine_s8(_val0, _val0); - _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum0 = vdotq_s32(_sum0, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + + tmpptr += 8; + kptr0 += 32; + } + _sum0 = vpaddq_s32(_sum0, _sum23); +#else // __ARM_FEATURE_MATMUL_INT8 + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); tmpptr += 8; kptr0 += 32; } +#endif // __ARM_FEATURE_MATMUL_INT8 if (nn4 > 0) { @@ -2437,12 +3063,22 @@ static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_b static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if 
(ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon_i8mm(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon_asimddp(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); return; } +#endif #endif const int maxk = kernel_w * kernel_h; @@ -2450,33 +3086,65 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const M // interleave // src = maxk-inch-outch // dst = 8a-4b-maxk-inch/8a-outch/4b - // dst = 4a-4b-2-maxk-inch/8a-outch/4b (arm82) + // dst = 4a-4b-2aa-2bb-maxk-inch/8a-outch/8b (arm82) + // dst = 8a-8b-maxk-inch/8a-outch/8b (arm84) Mat kernel = _kernel.reshape(maxk, inch, outch); +#if __ARM_FEATURE_DOTPROD + if (outch >= 8) + { + if (inch >= 8) + kernel_tm.create(64 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 8 + (outch % 8) / 4, (size_t)1u); + else if (inch >= 4) + kernel_tm.create(32 * maxk, inch / 4 + inch % 4, outch / 8 + (outch % 8) / 4, (size_t)1u); + else + kernel_tm.create(8 * maxk, inch, outch / 8 + (outch % 8) / 4, (size_t)1u); + } + else + { + if (inch >= 8) + kernel_tm.create(32 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 4, (size_t)1u); + else if (inch >= 4) + kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4, (size_t)1u); + else + kernel_tm.create(4 * maxk, inch, outch / 4, (size_t)1u); + } +#else // __ARM_FEATURE_DOTPROD if (inch >= 8) kernel_tm.create(32 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 4, (size_t)1u); else if (inch >= 4) kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4, (size_t)1u); else kernel_tm.create(4 * maxk, inch, outch / 4, (size_t)1u); +#endif // __ARM_FEATURE_DOTPROD - for (int q = 0; q + 3 < outch; q += 4) + int q = 0; +#if __ARM_FEATURE_DOTPROD + for (; q + 7 < outch; q += 8) { - signed 
char* g00 = kernel_tm.channel(q / 4); + signed char* g00 = kernel_tm.channel(q / 8); int p = 0; for (; p + 7 < inch; p += 8) { for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#else // __ARM_FEATURE_MATMUL_INT8 for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00++; } } @@ -2485,9 +3153,99 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const M for (int j = 4; j < 8; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#endif // __ARM_FEATURE_MATMUL_INT8 + } + } + for (; p + 3 < inch; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + } + } + for (; p < inch; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + const signed char* k00 = kernel.channel(q + i).row(p); + g00[0] = k00[k]; + g00++; + } + } + } + } +#endif // __ARM_FEATURE_DOTPROD + for (; q + 3 < outch; q += 4) + { +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else + signed char* g00 = kernel_tm.channel(q / 4); +#endif + int p = 0; + for (; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int 
j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#elif __ARM_FEATURE_DOTPROD + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 0; i < 4; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; g00++; } } @@ -2497,9 +3255,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const M for (int j = 0; j < 8; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00++; } } @@ -2515,9 +3271,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const M for (int j = 0; j < 4; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00++; } } @@ -2530,9 +3284,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const M for (int i = 0; i < 4; i++) { const signed char* k00 = kernel.channel(q + i).row(p); - g00[0] = k00[k]; - g00++; } } diff --git a/src/layer/arm/convolution_sgemm_pack8to1_int8.h b/src/layer/arm/convolution_sgemm_pack8to1_int8.h index 8c3dafc8b..cdcb6fcfa 100644 --- a/src/layer/arm/convolution_sgemm_pack8to1_int8.h +++ b/src/layer/arm/convolution_sgemm_pack8to1_int8.h @@ -12,19 +12,36 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
-#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void im2col_sgemm_pack8to1_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); +void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon_i8mm(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD void im2col_sgemm_pack8to1_int8_neon_asimddp(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon_asimddp(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); #endif +#endif static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + im2col_sgemm_pack8to1_int8_neon_i8mm(bottom_im2col, top_blob, kernel, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { im2col_sgemm_pack8to1_int8_neon_asimddp(bottom_im2col, top_blob, kernel, opt); return; } +#endif #endif // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -39,9 +56,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b Mat tmp; #if __aarch64__ #if __ARM_FEATURE_DOTPROD - if (size >= 16) - tmp.create(16 * maxk, inch, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, 
opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); @@ -66,15 +81,15 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - int nn_size = size >> 4; + int nn_size = size >> 3; int remain_size_start = 0; #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { - int i = remain_size_start + ii * 16; + int i = remain_size_start + ii * 8; - signed char* tmpptr = tmp.channel(i / 16); + signed char* tmpptr = tmp.channel(i / 8); for (int q = 0; q < inch; q++) { @@ -82,48 +97,17 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { - // split pack8to1 to pack4 +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( "prfm pldl1keep, [%0, #512] \n" - "ld2 {v0.4s, v1.4s}, [%0], #32 \n" - "ld2 {v2.4s, v3.4s}, [%0], #32 \n" - "ld2 {v4.4s, v5.4s}, [%0], #32 \n" - "ld2 {v6.4s, v7.4s}, [%0] \n" - "sub %0, %0, #96 \n" - "st1 {v0.16b}, [%1], #16 \n" - "st1 {v2.16b}, [%1], #16 \n" - "st1 {v4.16b}, [%1], #16 \n" - "st1 {v6.16b}, [%1], #16 \n" - "st1 {v1.16b}, [%1], #16 \n" - "st1 {v3.16b}, [%1], #16 \n" - "st1 {v5.16b}, [%1], #16 \n" - "st1 {v7.16b}, [%1], #16 \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%0] \n" + "st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" : "=r"(img0), // %0 "=r"(tmpptr) // %1 : "0"(img0), "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); - img0 += size * 8; - } - } - } - - remain_size_start += nn_size << 4; - nn_size = (size - remain_size_start) >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) - { - int i = remain_size_start + ii * 8; - - signed char* tmpptr = tmp.channel(i / 
16 + (i % 16) / 8); - - for (int q = 0; q < inch; q++) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i * 8; - - for (int k = 0; k < maxk; k++) - { + : "memory", "v0", "v1", "v2", "v3"); +#else // __ARM_FEATURE_MATMUL_INT8 asm volatile( "prfm pldl1keep, [%0, #512] \n" "ld2 {v0.4s, v1.4s}, [%0], #32 \n" @@ -138,6 +122,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0", "v1", "v2", "v3"); +#endif // __ARM_FEATURE_MATMUL_INT8 img0 += size * 8; } } @@ -156,7 +141,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b int i = remain_size_start + ii * 4; #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); #else signed char* tmpptr = tmp.channel(i / 4); #endif @@ -167,7 +152,17 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "prfm pldl1keep, [%0, #256] \n" + "ld1 {v0.16b, v1.16b}, [%0] \n" + "st1 {v0.16b, v1.16b}, [%1], #32 \n" + : "=r"(img0), // %0 + "=r"(tmpptr) // %1 + : "0"(img0), + "1"(tmpptr) + : "memory", "v0", "v1"); +#elif __ARM_FEATURE_DOTPROD asm volatile( "prfm pldl1keep, [%0, #256] \n" "ld2 {v0.4s, v1.4s}, [%0] \n" @@ -187,7 +182,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0", "v1"); -#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size * 8; } } @@ -195,10 +190,10 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b remain_size_start += nn_size << 2; nn_size = (size - remain_size_start) >> 1; -#else +#else // __aarch64__ int remain_size_start = 0; int nn_size = (size - remain_size_start) >> 1; -#endif +#endif // __aarch64__ #pragma omp parallel for 
num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) @@ -207,7 +202,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); #endif @@ -222,7 +217,17 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { #if __aarch64__ -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v0.16b}, [%0] \n" + "st1 {v0.16b}, [%1], #16 \n" + : "=r"(img0), // %0 + "=r"(tmpptr) // %1 + : "0"(img0), + "1"(tmpptr) + : "memory", "v0"); +#elif __ARM_FEATURE_DOTPROD asm volatile( "prfm pldl1keep, [%0, #128] \n" "ld2 {v0.2s, v1.2s}, [%0] \n" @@ -242,8 +247,8 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0"); -#endif // __ARM_FEATURE_DOTPROD -#else +#endif +#else // __aarch64__ asm volatile( "pld [%0, #128] \n" "vld1.s8 {d0-d1}, [%0 :64] \n" @@ -253,7 +258,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "q0"); -#endif +#endif // __aarch64__ img0 += size * 8; } } @@ -266,7 +271,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); #endif @@ -290,7 +295,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0"); -#else +#else // 
__aarch64__ asm volatile( "pld [%0, #64] \n" "vld1.s8 {d0}, [%0 :64] \n" @@ -300,7 +305,7 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "d0"); -#endif +#endif // __aarch64__ img0 += size * 8; } } @@ -309,37 +314,36 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b int nn_outch = 0; int remain_outch_start = 0; - - nn_outch = outch >> 2; +#if __ARM_FEATURE_DOTPROD + nn_outch = outch / 8; + remain_outch_start = nn_outch * 8; #pragma omp parallel for num_threads(opt.num_threads) for (int pp = 0; pp < nn_outch; pp++) { - int p = pp * 4; + int p = pp * 8; int* outptr0 = top_blob.channel(p); int* outptr1 = top_blob.channel(p + 1); int* outptr2 = top_blob.channel(p + 2); int* outptr3 = top_blob.channel(p + 3); + int* outptr4 = top_blob.channel(p + 4); + int* outptr5 = top_blob.channel(p + 5); + int* outptr6 = top_blob.channel(p + 6); + int* outptr7 = top_blob.channel(p + 7); int i = 0; -#if __aarch64__ -#if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p / 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8); int nn = inch * maxk; // inch always > 0 +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( - "ld1 {v24.16b}, [%6], #16 \n" // _w0123_l - "eor v0.16b, v0.16b, v0.16b \n" "eor v1.16b, v1.16b, v1.16b \n" - - "ld1 {v16.16b}, [%5], #16 \n" // _val0123_l - "eor v2.16b, v2.16b, v2.16b \n" "eor v3.16b, v3.16b, v3.16b \n" "eor v4.16b, v4.16b, v4.16b \n" @@ -357,77 +361,144 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b "0: \n" - "ld1 {v17.16b}, [%5], #16 \n" // _val4567_l - - "sdot v0.4s, v24.16b, v16.4b[0] \n" - "sdot v1.4s, v24.16b, v16.4b[1] \n" - "sdot v2.4s, v24.16b, v16.4b[2] \n" - "sdot v3.4s, v24.16b, v16.4b[3] \n" + "ld1 {v16.16b, v17.16b, 
v18.16b, v19.16b}, [%9], #64 \n" // _val0 _val1 _val1 _val3 + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%10], #64 \n" // _w01 _w23 _w45 _w67 + + "smmla v0.4s, v16.16b, v20.16b \n" + "smmla v1.4s, v16.16b, v21.16b \n" + "smmla v2.4s, v17.16b, v20.16b \n" + "smmla v3.4s, v17.16b, v21.16b \n" + "smmla v4.4s, v18.16b, v20.16b \n" + "smmla v5.4s, v18.16b, v21.16b \n" + "smmla v6.4s, v19.16b, v20.16b \n" + "smmla v7.4s, v19.16b, v21.16b \n" + + "subs %w8, %w8, #1 \n" + + "smmla v8.4s, v16.16b, v22.16b \n" + "smmla v9.4s, v16.16b, v23.16b \n" + "smmla v10.4s, v17.16b, v22.16b \n" + "smmla v11.4s, v17.16b, v23.16b \n" + "smmla v12.4s, v18.16b, v22.16b \n" + "smmla v13.4s, v18.16b, v23.16b \n" + "smmla v14.4s, v19.16b, v22.16b \n" + "smmla v15.4s, v19.16b, v23.16b \n" - "ld1 {v18.16b}, [%5], #16 \n" // _val891011_l - - "sdot v4.4s, v24.16b, v17.4b[0] \n" - "sdot v5.4s, v24.16b, v17.4b[1] \n" - "sdot v6.4s, v24.16b, v17.4b[2] \n" - "sdot v7.4s, v24.16b, v17.4b[3] \n" - - "ld1 {v19.16b}, [%5], #16 \n" // _val12131415_l - - "sdot v8.4s, v24.16b, v18.4b[0] \n" - "sdot v9.4s, v24.16b, v18.4b[1] \n" - - "ld1 {v25.16b}, [%6], #16 \n" // _w0123_h - - "sdot v10.4s, v24.16b, v18.4b[2] \n" - "sdot v11.4s, v24.16b, v18.4b[3] \n" - - "ld1 {v20.16b}, [%5], #16 \n" // _val0123_h - - "sdot v12.4s, v24.16b, v19.4b[0] \n" - "sdot v13.4s, v24.16b, v19.4b[1] \n" - "sdot v14.4s, v24.16b, v19.4b[2] \n" - "sdot v15.4s, v24.16b, v19.4b[3] \n" - - "ld1 {v21.16b}, [%5], #16 \n" // _val4567_h - - "sdot v0.4s, v25.16b, v20.4b[0] \n" - "sdot v1.4s, v25.16b, v20.4b[1] \n" - "sdot v2.4s, v25.16b, v20.4b[2] \n" - "sdot v3.4s, v25.16b, v20.4b[3] \n" - - "ld1 {v22.16b}, [%5], #16 \n" // _val891011_h - - "sdot v4.4s, v25.16b, v21.4b[0] \n" - "sdot v5.4s, v25.16b, v21.4b[1] \n" - "sdot v6.4s, v25.16b, v21.4b[2] \n" - "sdot v7.4s, v25.16b, v21.4b[3] \n" - - "ld1 {v23.16b}, [%5], #16 \n" // _val12131415_h - - "sdot v8.4s, v25.16b, v22.4b[0] \n" - "sdot v9.4s, v25.16b, v22.4b[1] \n" - - "ld1 {v24.16b}, [%6], 
#16 \n" // _w0123_l - - "sdot v10.4s, v25.16b, v22.4b[2] \n" - "sdot v11.4s, v25.16b, v22.4b[3] \n" - - "ld1 {v16.16b}, [%5], #16 \n" // _val0123_l + "bne 0b \n" - "sdot v12.4s, v25.16b, v23.4b[0] \n" - "sdot v13.4s, v25.16b, v23.4b[1] \n" + "uzp1 v16.4s, v0.4s, v2.4s \n" + "uzp2 v18.4s, v0.4s, v2.4s \n" + "uzp1 v20.4s, v1.4s, v3.4s \n" + "uzp2 v22.4s, v1.4s, v3.4s \n" + "uzp1 v17.4s, v4.4s, v6.4s \n" + "uzp2 v19.4s, v4.4s, v6.4s \n" + "uzp1 v21.4s, v5.4s, v7.4s \n" + "uzp2 v23.4s, v5.4s, v7.4s \n" + + "uzp1 v0.4s, v8.4s, v10.4s \n" + "uzp2 v2.4s, v8.4s, v10.4s \n" + "uzp1 v4.4s, v9.4s, v11.4s \n" + "uzp2 v6.4s, v9.4s, v11.4s \n" + "uzp1 v1.4s, v12.4s, v14.4s \n" + "uzp2 v3.4s, v12.4s, v14.4s \n" + "uzp1 v5.4s, v13.4s, v15.4s \n" + "uzp2 v7.4s, v13.4s, v15.4s \n" + + "st1 {v16.4s, v17.4s}, [%0], #32 \n" + "st1 {v18.4s, v19.4s}, [%1], #32 \n" + "st1 {v20.4s, v21.4s}, [%2], #32 \n" + "st1 {v22.4s, v23.4s}, [%3], #32 \n" + "st1 {v0.4s, v1.4s}, [%4], #32 \n" + "st1 {v2.4s, v3.4s}, [%5], #32 \n" + "st1 {v4.4s, v5.4s}, [%6], #32 \n" + "st1 {v6.4s, v7.4s}, [%7], #32 \n" + : "=r"(outptr0), + "=r"(outptr1), + "=r"(outptr2), + "=r"(outptr3), + "=r"(outptr4), + "=r"(outptr5), + "=r"(outptr6), + "=r"(outptr7), + "=r"(nn), + "=r"(tmpptr), + "=r"(kptr0) + : "0"(outptr0), + "1"(outptr1), + "2"(outptr2), + "3"(outptr3), + "4"(outptr4), + "5"(outptr5), + "6"(outptr6), + "7"(outptr7), + "8"(nn), + "9"(tmpptr), + "10"(kptr0) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else // __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "eor v0.16b, v0.16b, v0.16b \n" + "eor v1.16b, v1.16b, v1.16b \n" + "eor v2.16b, v2.16b, v2.16b \n" + "eor v3.16b, v3.16b, v3.16b \n" + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + 
"eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" - "subs %w4, %w4, #1 \n" + "0: \n" - "sdot v14.4s, v25.16b, v23.4b[2] \n" - "sdot v15.4s, v25.16b, v23.4b[3] \n" + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%9], #64 \n" // _val0 _val1 _val2 _val3 + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%10], #64 \n" // _w01 _w23 _w45 _w67 + + "sdot v0.4s, v20.16b, v16.4b[0] \n" + "sdot v1.4s, v20.16b, v16.4b[1] \n" + "sdot v2.4s, v20.16b, v16.4b[2] \n" + "sdot v3.4s, v20.16b, v16.4b[3] \n" + "sdot v4.4s, v20.16b, v17.4b[0] \n" + "sdot v5.4s, v20.16b, v17.4b[1] \n" + "sdot v6.4s, v20.16b, v17.4b[2] \n" + "sdot v7.4s, v20.16b, v17.4b[3] \n" + + "sdot v0.4s, v21.16b, v18.4b[0] \n" + "sdot v1.4s, v21.16b, v18.4b[1] \n" + "sdot v2.4s, v21.16b, v18.4b[2] \n" + "sdot v3.4s, v21.16b, v18.4b[3] \n" + "sdot v4.4s, v21.16b, v19.4b[0] \n" + "sdot v5.4s, v21.16b, v19.4b[1] \n" + "sdot v6.4s, v21.16b, v19.4b[2] \n" + "sdot v7.4s, v21.16b, v19.4b[3] \n" + + "subs %w8, %w8, #1 \n" + + "sdot v8.4s, v22.16b, v16.4b[0] \n" + "sdot v9.4s, v22.16b, v16.4b[1] \n" + "sdot v10.4s, v22.16b, v16.4b[2] \n" + "sdot v11.4s, v22.16b, v16.4b[3] \n" + "sdot v12.4s, v22.16b, v17.4b[0] \n" + "sdot v13.4s, v22.16b, v17.4b[1] \n" + "sdot v14.4s, v22.16b, v17.4b[2] \n" + "sdot v15.4s, v22.16b, v17.4b[3] \n" + + "sdot v8.4s, v23.16b, v18.4b[0] \n" + "sdot v9.4s, v23.16b, v18.4b[1] \n" + "sdot v10.4s, v23.16b, v18.4b[2] \n" + "sdot v11.4s, v23.16b, v18.4b[3] \n" + "sdot v12.4s, v23.16b, v19.4b[0] \n" + "sdot v13.4s, v23.16b, v19.4b[1] \n" + "sdot v14.4s, v23.16b, v19.4b[2] \n" + "sdot v15.4s, v23.16b, v19.4b[3] \n" "bne 0b \n" - "sub %5, %5, #16 \n" - "sub %6, %6, #16 \n" - - // transpose 4x16 "trn1 v16.4s, v0.4s, v1.4s \n" "trn2 v17.4s, v0.4s, v1.4s \n" "trn1 v18.4s, v2.4s, v3.4s \n" @@ -436,43 +507,50 @@ static void 
im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b "trn2 v21.4s, v4.4s, v5.4s \n" "trn1 v22.4s, v6.4s, v7.4s \n" "trn2 v23.4s, v6.4s, v7.4s \n" - "trn1 v24.4s, v8.4s, v9.4s \n" - "trn2 v25.4s, v8.4s, v9.4s \n" - "trn1 v26.4s, v10.4s, v11.4s \n" - "trn2 v27.4s, v10.4s, v11.4s \n" - "trn1 v28.4s, v12.4s, v13.4s \n" - "trn2 v29.4s, v12.4s, v13.4s \n" - "trn1 v30.4s, v14.4s, v15.4s \n" - "trn2 v31.4s, v14.4s, v15.4s \n" "trn1 v0.2d, v16.2d, v18.2d \n" - "trn2 v8.2d, v16.2d, v18.2d \n" - "trn1 v4.2d, v17.2d, v19.2d \n" - "trn2 v12.2d, v17.2d, v19.2d \n" - + "trn1 v2.2d, v17.2d, v19.2d \n" + "trn2 v4.2d, v16.2d, v18.2d \n" + "trn2 v6.2d, v17.2d, v19.2d \n" "trn1 v1.2d, v20.2d, v22.2d \n" - "trn2 v9.2d, v20.2d, v22.2d \n" - "trn1 v5.2d, v21.2d, v23.2d \n" - "trn2 v13.2d, v21.2d, v23.2d \n" - - "trn1 v2.2d, v24.2d, v26.2d \n" - "trn2 v10.2d, v24.2d, v26.2d \n" - "trn1 v6.2d, v25.2d, v27.2d \n" - "trn2 v14.2d, v25.2d, v27.2d \n" - - "trn1 v3.2d, v28.2d, v30.2d \n" - "trn2 v11.2d, v28.2d, v30.2d \n" - "trn1 v7.2d, v29.2d, v31.2d \n" - "trn2 v15.2d, v29.2d, v31.2d \n" - - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n" - "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%1], #64 \n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%2], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%3], #64 \n" + "trn1 v3.2d, v21.2d, v23.2d \n" + "trn2 v5.2d, v20.2d, v22.2d \n" + "trn2 v7.2d, v21.2d, v23.2d \n" + + "trn1 v16.4s, v8.4s, v9.4s \n" + "trn2 v17.4s, v8.4s, v9.4s \n" + "trn1 v18.4s, v10.4s, v11.4s \n" + "trn2 v19.4s, v10.4s, v11.4s \n" + "trn1 v20.4s, v12.4s, v13.4s \n" + "trn2 v21.4s, v12.4s, v13.4s \n" + "trn1 v22.4s, v14.4s, v15.4s \n" + "trn2 v23.4s, v14.4s, v15.4s \n" + + "trn1 v8.2d, v16.2d, v18.2d \n" + "trn1 v10.2d, v17.2d, v19.2d \n" + "trn2 v12.2d, v16.2d, v18.2d \n" + "trn2 v14.2d, v17.2d, v19.2d \n" + "trn1 v9.2d, v20.2d, v22.2d \n" + "trn1 v11.2d, v21.2d, v23.2d \n" + "trn2 v13.2d, v20.2d, v22.2d \n" + "trn2 v15.2d, v21.2d, v23.2d \n" + + "st1 {v0.4s, v1.4s}, [%0], #32 
\n" + "st1 {v2.4s, v3.4s}, [%1], #32 \n" + "st1 {v4.4s, v5.4s}, [%2], #32 \n" + "st1 {v6.4s, v7.4s}, [%3], #32 \n" + "st1 {v8.4s, v9.4s}, [%4], #32 \n" + "st1 {v10.4s, v11.4s}, [%5], #32 \n" + "st1 {v12.4s, v13.4s}, [%6], #32 \n" + "st1 {v14.4s, v15.4s}, [%7], #32 \n" : "=r"(outptr0), "=r"(outptr1), "=r"(outptr2), "=r"(outptr3), + "=r"(outptr4), + "=r"(outptr5), + "=r"(outptr6), + "=r"(outptr7), "=r"(nn), "=r"(tmpptr), "=r"(kptr0) @@ -480,18 +558,419 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b "1"(outptr1), "2"(outptr2), "3"(outptr3), - "4"(nn), - "5"(tmpptr), - "6"(kptr0) - : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + "4"(outptr4), + "5"(outptr5), + "6"(outptr6), + "7"(outptr7), + "8"(nn), + "9"(tmpptr), + "10"(kptr0) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#endif // __ARM_FEATURE_MATMUL_INT8 + } + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = inch * maxk; // inch always > 0 + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = 
vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + _sum4 = vmmlaq_s32(_sum4, _val0, _w45); + _sum5 = vmmlaq_s32(_sum5, _val0, _w67); + _sum6 = vmmlaq_s32(_sum6, _val1, _w45); + _sum7 = vmmlaq_s32(_sum7, _val1, _w67); + + tmpptr += 32; + kptr0 += 64; + } + + int32x4x2_t _sum02 = vuzpq_s32(_sum0, _sum2); + int32x4x2_t _sum13 = vuzpq_s32(_sum1, _sum3); + int32x4x2_t _sum46 = vuzpq_s32(_sum4, _sum6); + int32x4x2_t _sum57 = vuzpq_s32(_sum5, _sum7); + + vst1q_s32(outptr0, _sum02.val[0]); + vst1q_s32(outptr1, _sum02.val[1]); + vst1q_s32(outptr2, _sum13.val[0]); + vst1q_s32(outptr3, _sum13.val[1]); + vst1q_s32(outptr4, _sum46.val[0]); + vst1q_s32(outptr5, _sum46.val[1]); + vst1q_s32(outptr6, _sum57.val[0]); + vst1q_s32(outptr7, _sum57.val[1]); + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + outptr4 += 4; + outptr5 += 4; + outptr6 += 4; + outptr7 += 4; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val0123_l = vld1q_s8(tmpptr); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1); + _sum2 = 
vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3); + + _sum4 = vdotq_laneq_s32(_sum4, _w4567_l, _val0123_l, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_l, _val0123_l, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_l, _val0123_l, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_l, _val0123_l, 3); + _sum4 = vdotq_laneq_s32(_sum4, _w4567_h, _val0123_h, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_h, _val0123_h, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_h, _val0123_h, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_h, _val0123_h, 3); + + tmpptr += 32; + kptr0 += 64; + } + + // transpose 4x4 + int32x4_t _sum01_0 = vtrn1q_s32(_sum0, _sum1); + int32x4_t _sum01_1 = vtrn2q_s32(_sum0, _sum1); + int32x4_t _sum23_0 = vtrn1q_s32(_sum2, _sum3); + int32x4_t _sum23_1 = vtrn2q_s32(_sum2, _sum3); + int32x4_t _sum45_0 = vtrn1q_s32(_sum4, _sum5); + int32x4_t _sum45_1 = vtrn2q_s32(_sum4, _sum5); + int32x4_t _sum67_0 = vtrn1q_s32(_sum6, _sum7); + int32x4_t _sum67_1 = vtrn2q_s32(_sum6, _sum7); + _sum0 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum01_0), vreinterpretq_s64_s32(_sum23_0))); + _sum1 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum01_1), vreinterpretq_s64_s32(_sum23_1))); + _sum2 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum01_0), vreinterpretq_s64_s32(_sum23_0))); + _sum3 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum01_1), vreinterpretq_s64_s32(_sum23_1))); + _sum4 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum45_0), vreinterpretq_s64_s32(_sum67_0))); + _sum5 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(_sum45_1), vreinterpretq_s64_s32(_sum67_1))); + _sum6 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum45_0), vreinterpretq_s64_s32(_sum67_0))); + _sum7 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(_sum45_1), vreinterpretq_s64_s32(_sum67_1))); + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr1, _sum1); + 
vst1q_s32(outptr2, _sum2); + vst1q_s32(outptr3, _sum3); + vst1q_s32(outptr4, _sum4); + vst1q_s32(outptr5, _sum5); + vst1q_s32(outptr6, _sum6); + vst1q_s32(outptr7, _sum7); + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + outptr4 += 4; + outptr5 += 4; + outptr6 += 4; + outptr7 += 4; +#endif // __ARM_FEATURE_MATMUL_INT8 + } + for (; i + 1 < size; i += 2) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = inch * maxk; // inch always > 0 + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + _sum2 = vmmlaq_s32(_sum2, _val, _w45); + _sum3 = vmmlaq_s32(_sum3, _val, _w67); + + tmpptr += 16; + kptr0 += 64; + } + + int32x4x2_t _sum01 = vuzpq_s32(_sum0, _sum1); + int32x4x2_t _sum23 = vuzpq_s32(_sum2, _sum3); + + vst1_s32(outptr0, vget_low_s32(_sum01.val[0])); + vst1_s32(outptr1, vget_low_s32(_sum01.val[1])); + vst1_s32(outptr2, vget_high_s32(_sum01.val[0])); + vst1_s32(outptr3, vget_high_s32(_sum01.val[1])); + vst1_s32(outptr4, vget_low_s32(_sum23.val[0])); + vst1_s32(outptr5, vget_low_s32(_sum23.val[1])); + vst1_s32(outptr6, vget_high_s32(_sum23.val[0])); + vst1_s32(outptr7, vget_high_s32(_sum23.val[1])); + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; + outptr4 += 2; + outptr5 += 2; + outptr6 += 2; + outptr7 += 2; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 0; j < 
nn; j++) + { + int8x16_t _val01_l_h = vld1q_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val01_l_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val01_l_h, 1); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val01_l_h, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val01_l_h, 3); + + _sum2 = vdotq_laneq_s32(_sum2, _w4567_l, _val01_l_h, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_l, _val01_l_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w4567_h, _val01_l_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_h, _val01_l_h, 3); + + tmpptr += 16; + kptr0 += 64; + } + + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _sum23 = vzipq_s32(_sum2, _sum3); + + vst1_s32(outptr0, vget_low_s32(_sum01.val[0])); + vst1_s32(outptr1, vget_high_s32(_sum01.val[0])); + vst1_s32(outptr2, vget_low_s32(_sum01.val[1])); + vst1_s32(outptr3, vget_high_s32(_sum01.val[1])); + vst1_s32(outptr4, vget_low_s32(_sum23.val[0])); + vst1_s32(outptr5, vget_high_s32(_sum23.val[0])); + vst1_s32(outptr6, vget_low_s32(_sum23.val[1])); + vst1_s32(outptr7, vget_high_s32(_sum23.val[1])); + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; + outptr4 += 2; + outptr5 += 2; + outptr6 += 2; + outptr7 += 2; +#endif // __ARM_FEATURE_MATMUL_INT8 + } + for (; i < size; i++) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 8); + + int nn = inch * maxk; // inch always > 0 + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); 
+ int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + _sum45 = vdotq_s32(_sum45, _val, _w45); + _sum67 = vdotq_s32(_sum67, _val, _w67); + + tmpptr += 8; + kptr0 += 64; + } + + int32x4_t _s0123 = vpaddq_s32(_sum01, _sum23); + int32x4_t _s4567 = vpaddq_s32(_sum45, _sum67); + + outptr0[0] = vgetq_lane_s32(_s0123, 0); + outptr1[0] = vgetq_lane_s32(_s0123, 1); + outptr2[0] = vgetq_lane_s32(_s0123, 2); + outptr3[0] = vgetq_lane_s32(_s0123, 3); + outptr4[0] = vgetq_lane_s32(_s4567, 0); + outptr5[0] = vgetq_lane_s32(_s4567, 1); + outptr6[0] = vgetq_lane_s32(_s4567, 2); + outptr7[0] = vgetq_lane_s32(_s4567, 3); + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; + outptr4 += 1; + outptr5 += 1; + outptr6 += 1; + outptr7 += 1; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); + _sum1 = vdotq_lane_s32(_sum1, _w4567_l, _val0_l_h, 0); + _sum1 = vdotq_lane_s32(_sum1, _w4567_h, _val0_l_h, 1); + + tmpptr += 8; + kptr0 += 64; + } + + outptr0[0] = vgetq_lane_s32(_sum0, 0); + outptr1[0] = vgetq_lane_s32(_sum0, 1); + outptr2[0] = vgetq_lane_s32(_sum0, 2); + outptr3[0] = vgetq_lane_s32(_sum0, 3); + outptr4[0] = vgetq_lane_s32(_sum1, 0); + outptr5[0] = vgetq_lane_s32(_sum1, 1); + outptr6[0] = vgetq_lane_s32(_sum1, 2); + outptr7[0] = vgetq_lane_s32(_sum1, 3); + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; + outptr4 += 1; + outptr5 += 1; + outptr6 += 1; + outptr7 += 1; +#endif // __ARM_FEATURE_MATMUL_INT8 } + } +#endif 
// __ARM_FEATURE_DOTPROD + + nn_outch = (outch - remain_outch_start) >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outch; pp++) + { + int p = remain_outch_start + pp * 4; + + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); + int* outptr2 = top_blob.channel(p + 2); + int* outptr3 = top_blob.channel(p + 3); + + int i = 0; +#if __aarch64__ +#if __ARM_FEATURE_DOTPROD for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p / 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); int nn = inch * maxk; // inch always > 0 +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _val2 = vld1q_s8(tmpptr + 32); + int8x16_t _val3 = vld1q_s8(tmpptr + 48); + + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + _sum4 = vmmlaq_s32(_sum4, _val2, _w01); + _sum5 = vmmlaq_s32(_sum5, _val2, _w23); + _sum6 = vmmlaq_s32(_sum6, _val3, _w01); + _sum7 = vmmlaq_s32(_sum7, _val3, _w23); + + tmpptr += 64; + kptr0 += 32; + } + + int32x4x2_t _sum02 = vuzpq_s32(_sum0, _sum2); + int32x4x2_t _sum13 = vuzpq_s32(_sum1, _sum3); + int32x4x2_t _sum46 = vuzpq_s32(_sum4, _sum6); + int32x4x2_t _sum57 = vuzpq_s32(_sum5, _sum7); + + vst1q_s32(outptr0, _sum02.val[0]); + vst1q_s32(outptr1, _sum02.val[1]); + 
vst1q_s32(outptr2, _sum13.val[0]); + vst1q_s32(outptr3, _sum13.val[1]); + vst1q_s32(outptr0 + 4, _sum46.val[0]); + vst1q_s32(outptr1 + 4, _sum46.val[1]); + vst1q_s32(outptr2 + 4, _sum57.val[0]); + vst1q_s32(outptr3 + 4, _sum57.val[1]); + outptr0 += 8; + outptr1 += 8; + outptr2 += 8; + outptr3 += 8; +#else // __ARM_FEATURE_MATMUL_INT8 int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -561,20 +1040,55 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b outptr1 += 8; outptr2 += 8; outptr3 += 8; +#endif // __ARM_FEATURE_MATMUL_INT8 } #endif for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + tmpptr += 32; + kptr0 += 32; + } + + int32x4x2_t _sum02 = vuzpq_s32(_sum0, _sum2); + int32x4x2_t _sum13 = vuzpq_s32(_sum1, _sum3); + + vst1q_s32(outptr0, _sum02.val[0]); + vst1q_s32(outptr1, _sum02.val[1]); + vst1q_s32(outptr2, _sum13.val[0]); + vst1q_s32(outptr3, _sum13.val[1]); + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; +#elif 
__ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -828,19 +1342,48 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = inch * maxk; // inch always > 0 #if __aarch64__ -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + + tmpptr += 16; + kptr0 += 32; + } + + int32x4x2_t _sum01 = vuzpq_s32(_sum0, _sum1); + + vst1_s32(outptr0, vget_low_s32(_sum01.val[0])); + vst1_s32(outptr1, vget_low_s32(_sum01.val[1])); + vst1_s32(outptr2, vget_high_s32(_sum01.val[0])); + vst1_s32(outptr3, vget_high_s32(_sum01.val[1])); + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; +#elif __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); @@ -1129,18 +1672,49 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = 
kernel.channel(p / 8 + (p % 8) / 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif const signed char* kptr0 = kernel.channel(p / 4); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + + tmpptr += 8; + kptr0 += 32; + } + + int32x4_t _s0123 = vpaddq_s32(_sum01, _sum23); + + vst1q_lane_s32(outptr0, _s0123, 0); + vst1q_lane_s32(outptr1, _s0123, 1); + vst1q_lane_s32(outptr2, _s0123, 2); + vst1q_lane_s32(outptr3, _s0123, 3); + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; +#elif __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); for (int j = 0; j < nn; j++) @@ -1259,57 +1833,52 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b int i = 0; #if __aarch64__ #if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p / 4 + p % 4); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); int nn = inch * maxk; // inch always > 0 - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum0 = vdup_n_s32(0); + int32x2_t _sum1 = vdup_n_s32(0); + int32x2_t _sum2 = vdup_n_s32(0); + int32x2_t _sum3 = vdup_n_s32(0); + int32x2_t _sum4 = 
vdup_n_s32(0); + int32x2_t _sum5 = vdup_n_s32(0); + int32x2_t _sum6 = vdup_n_s32(0); + int32x2_t _sum7 = vdup_n_s32(0); int j = 0; for (; j < nn; j++) { - int8x16_t _val0123_l = vld1q_s8(tmpptr); - int8x16_t _val4567_l = vld1q_s8(tmpptr + 16); - int8x16_t _val89ab_l = vld1q_s8(tmpptr + 32); - int8x16_t _valcdef_l = vld1q_s8(tmpptr + 48); - int8x16_t _val0123_h = vld1q_s8(tmpptr + 64); - int8x16_t _val4567_h = vld1q_s8(tmpptr + 80); - int8x16_t _val89ab_h = vld1q_s8(tmpptr + 96); - int8x16_t _valcdef_h = vld1q_s8(tmpptr + 112); - int8x8_t _w_lh = vld1_s8(kptr0); + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _val2 = vld1q_s8(tmpptr + 32); + int8x16_t _val3 = vld1q_s8(tmpptr + 48); + int8x8_t _w = vld1_s8(kptr0); - _sum0 = vdotq_lane_s32(_sum0, _val0123_l, _w_lh, 0); - _sum1 = vdotq_lane_s32(_sum1, _val4567_l, _w_lh, 0); - _sum2 = vdotq_lane_s32(_sum2, _val89ab_l, _w_lh, 0); - _sum3 = vdotq_lane_s32(_sum3, _valcdef_l, _w_lh, 0); - _sum0 = vdotq_lane_s32(_sum0, _val0123_h, _w_lh, 1); - _sum1 = vdotq_lane_s32(_sum1, _val4567_h, _w_lh, 1); - _sum2 = vdotq_lane_s32(_sum2, _val89ab_h, _w_lh, 1); - _sum3 = vdotq_lane_s32(_sum3, _valcdef_h, _w_lh, 1); - - tmpptr += 128; + _sum0 = vdot_s32(_sum0, vget_low_s8(_val0), _w); + _sum1 = vdot_s32(_sum1, vget_high_s8(_val0), _w); + _sum2 = vdot_s32(_sum2, vget_low_s8(_val1), _w); + _sum3 = vdot_s32(_sum3, vget_high_s8(_val1), _w); + _sum4 = vdot_s32(_sum4, vget_low_s8(_val2), _w); + _sum5 = vdot_s32(_sum5, vget_high_s8(_val2), _w); + _sum6 = vdot_s32(_sum6, vget_low_s8(_val3), _w); + _sum7 = vdot_s32(_sum7, vget_high_s8(_val3), _w); + + tmpptr += 64; kptr0 += 8; } - vst1q_s32(outptr0, _sum0); - vst1q_s32(outptr0 + 4, _sum1); - vst1q_s32(outptr0 + 8, _sum2); - vst1q_s32(outptr0 + 12, _sum3); - outptr0 += 16; - } - for (; i + 7 < size; i += 8) - { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p / 4 + p % 4); - - int nn = inch 
* maxk; // inch always > 0 + int32x4_t _ss = vpaddq_s32(vcombine_s32(_sum0, _sum1), vcombine_s32(_sum2, _sum3)); + int32x4_t _ss2 = vpaddq_s32(vcombine_s32(_sum4, _sum5), vcombine_s32(_sum6, _sum7)); + vst1q_s32(outptr0, _ss); + vst1q_s32(outptr0 + 4, _ss2); + outptr0 += 8; +#else // __ARM_FEATURE_MATMUL_INT8 int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -1339,20 +1908,48 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b vst1q_s32(outptr0, _sum0); vst1q_s32(outptr0 + 4, _sum1); outptr0 += 8; +#endif // __ARM_FEATURE_MATMUL_INT8 } #endif // __ARM_FEATURE_DOTPROD for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p / 4 + p % 4); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum0 = vdup_n_s32(0); + int32x2_t _sum1 = vdup_n_s32(0); + int32x2_t _sum2 = vdup_n_s32(0); + int32x2_t _sum3 = vdup_n_s32(0); + + int j = 0; + for (; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x8_t _w = vld1_s8(kptr0); + + _sum0 = vdot_s32(_sum0, vget_low_s8(_val0), _w); + _sum1 = vdot_s32(_sum1, vget_high_s8(_val0), _w); + _sum2 = vdot_s32(_sum2, vget_low_s8(_val1), _w); + _sum3 = vdot_s32(_sum3, vget_high_s8(_val1), _w); + + tmpptr += 32; + kptr0 += 8; + } + + int32x4_t _ss = vpaddq_s32(vcombine_s32(_sum0, _sum1), vcombine_s32(_sum2, _sum3)); + + vst1q_s32(outptr0, _ss); + outptr0 += 4; +#elif __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); @@ -1461,18 +2058,41 @@ static void 
im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif const signed char* kptr0 = kernel.channel(p / 4 + p % 4); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum0 = vdup_n_s32(0); + int32x2_t _sum1 = vdup_n_s32(0); + + int j = 0; + for (; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x8_t _w = vld1_s8(kptr0); + + _sum0 = vdot_s32(_sum0, vget_low_s8(_val), _w); + _sum1 = vdot_s32(_sum1, vget_high_s8(_val), _w); + + tmpptr += 16; + kptr0 += 8; + } + + int32x2_t _ss = vpadd_s32(_sum0, _sum1); + + vst1_s32(outptr0, _ss); + outptr0 += 2; +#elif __ARM_FEATURE_DOTPROD int32x2_t _sum0 = vdup_n_s32(0); int32x2_t _sum1 = vdup_n_s32(0); @@ -1552,14 +2172,16 @@ static void im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 8 + (p % 8) / 4 + p % 4); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif const signed char* kptr0 = kernel.channel(p / 4 + p % 4); +#endif int nn = inch * maxk; // inch always > 0 @@ -1644,36 +2266,131 @@ static void 
im2col_sgemm_pack8to1_int8_neon(const Mat& bottom_im2col, Mat& top_b static void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon_i8mm(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon_asimddp(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); return; } +#endif #endif const int maxk = kernel_w * kernel_h; // interleave // src = maxk-inch-outch - // dst = 8a-4b-maxk-inch/8a-outch/4b - // dst = 4a-4b-2-maxk-inch/8a-outch/4b (arm82) + // dst = 8a-4b-maxk-inch/8a-outch/4 + // dst = 4a-4b-2aa-2bb-maxk-inch/8a-outch/8b (arm82) + // dst = 8a-8b-maxk-inch/8a-outch/8b (arm84) Mat kernel = _kernel.reshape(maxk, inch, outch); +#if __ARM_FEATURE_DOTPROD + if (outch >= 8) + kernel_tm.create(64 * maxk, inch / 8, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u); + else if (outch >= 4) + kernel_tm.create(32 * maxk, inch / 8, outch / 4 + outch % 4, (size_t)1u); + else + kernel_tm.create(8 * maxk, inch / 8, outch, (size_t)1u); +#else if (outch >= 4) kernel_tm.create(32 * maxk, inch / 8, outch / 4 + outch % 4, (size_t)1u); else kernel_tm.create(8 * maxk, inch / 8, outch, (size_t)1u); +#endif int q = 0; +#if __ARM_FEATURE_DOTPROD + for (; q + 7 < outch; q += 8) + { + signed char* g00 = kernel_tm.channel(q / 8); + + for (int p = 0; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i 
< 8; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 0; i < 4; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#endif // __ARM_FEATURE_MATMUL_INT8 + } + } + } +#endif // __ARM_FEATURE_DOTPROD for (; q + 3 < outch; q += 4) { +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else signed char* g00 = kernel_tm.channel(q / 4); +#endif for (int p = 0; p + 7 < inch; p += 8) { for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#elif __ARM_FEATURE_DOTPROD for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) @@ -1715,7 +2432,11 @@ static void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_neon(const M // TODO unroll 2 for (; q < outch; q++) { +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + q % 4); +#else signed char* g00 = kernel_tm.channel(q / 4 + q % 4); +#endif for (int p = 0; p + 7 < inch; p += 8) { diff --git a/src/layer/arm/convolution_sgemm_pack8to4_int8.h b/src/layer/arm/convolution_sgemm_pack8to4_int8.h index 17f9d09a4..2bf34441f 100644 --- 
a/src/layer/arm/convolution_sgemm_pack8to4_int8.h +++ b/src/layer/arm/convolution_sgemm_pack8to4_int8.h @@ -12,19 +12,36 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void im2col_sgemm_pack8to4_int8_neon_i8mm(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); +void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon_i8mm(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD void im2col_sgemm_pack8to4_int8_neon_asimddp(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt); void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon_asimddp(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); #endif +#endif static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + im2col_sgemm_pack8to4_int8_neon_i8mm(bottom_im2col, top_blob, kernel, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { im2col_sgemm_pack8to4_int8_neon_asimddp(bottom_im2col, top_blob, kernel, opt); return; } +#endif #endif // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); @@ -39,9 
+56,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b Mat tmp; #if __aarch64__ #if __ARM_FEATURE_DOTPROD - if (size >= 16) - tmp.create(16 * maxk, inch, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); - else if (size >= 8) + if (size >= 8) tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); else if (size >= 4) tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); @@ -66,15 +81,15 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - int nn_size = size >> 4; + int nn_size = size >> 3; int remain_size_start = 0; #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { - int i = remain_size_start + ii * 16; + int i = remain_size_start + ii * 8; - signed char* tmpptr = tmp.channel(i / 16); + signed char* tmpptr = tmp.channel(i / 8); for (int q = 0; q < inch; q++) { @@ -82,48 +97,17 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { - // split pack8 to pack4 +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( "prfm pldl1keep, [%0, #512] \n" - "ld2 {v0.4s, v1.4s}, [%0], #32 \n" - "ld2 {v2.4s, v3.4s}, [%0], #32 \n" - "ld2 {v4.4s, v5.4s}, [%0], #32 \n" - "ld2 {v6.4s, v7.4s}, [%0] \n" - "sub %0, %0, #96 \n" - "st1 {v0.16b}, [%1], #16 \n" - "st1 {v2.16b}, [%1], #16 \n" - "st1 {v4.16b}, [%1], #16 \n" - "st1 {v6.16b}, [%1], #16 \n" - "st1 {v1.16b}, [%1], #16 \n" - "st1 {v3.16b}, [%1], #16 \n" - "st1 {v5.16b}, [%1], #16 \n" - "st1 {v7.16b}, [%1], #16 \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%0] \n" + "st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" : "=r"(img0), // %0 "=r"(tmpptr) // %1 : "0"(img0), "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); - img0 += size * 8; - } - } - 
} - - remain_size_start += nn_size << 4; - nn_size = (size - remain_size_start) >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) - { - int i = remain_size_start + ii * 8; - - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - - for (int q = 0; q < inch; q++) - { - const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i * 8; - - for (int k = 0; k < maxk; k++) - { + : "memory", "v0", "v1", "v2", "v3"); +#else // __ARM_FEATURE_MATMUL_INT8 asm volatile( "prfm pldl1keep, [%0, #512] \n" "ld2 {v0.4s, v1.4s}, [%0], #32 \n" @@ -138,6 +122,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0", "v1", "v2", "v3"); +#endif // __ARM_FEATURE_MATMUL_INT8 img0 += size * 8; } } @@ -156,7 +141,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b int i = remain_size_start + ii * 4; #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); #else signed char* tmpptr = tmp.channel(i / 4); #endif @@ -167,7 +152,17 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "prfm pldl1keep, [%0, #256] \n" + "ld1 {v0.16b, v1.16b}, [%0] \n" + "st1 {v0.16b, v1.16b}, [%1], #32 \n" + : "=r"(img0), // %0 + "=r"(tmpptr) // %1 + : "0"(img0), + "1"(tmpptr) + : "memory", "v0", "v1"); +#elif __ARM_FEATURE_DOTPROD asm volatile( "prfm pldl1keep, [%0, #256] \n" "ld2 {v0.4s, v1.4s}, [%0] \n" @@ -187,7 +182,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0", "v1"); -#endif // __ARM_FEATURE_DOTPROD +#endif img0 += size * 8; } } @@ -195,10 +190,10 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& 
bottom_im2col, Mat& top_b remain_size_start += nn_size << 2; nn_size = (size - remain_size_start) >> 1; -#else +#else // __aarch64__ int remain_size_start = 0; int nn_size = (size - remain_size_start) >> 1; -#endif +#endif // __aarch64__ #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) @@ -207,7 +202,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); #endif @@ -222,7 +217,17 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int k = 0; k < maxk; k++) { #if __aarch64__ -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v0.16b}, [%0] \n" + "st1 {v0.16b}, [%1], #16 \n" + : "=r"(img0), // %0 + "=r"(tmpptr) // %1 + : "0"(img0), + "1"(tmpptr) + : "memory", "v0"); +#elif __ARM_FEATURE_DOTPROD asm volatile( "prfm pldl1keep, [%0, #128] \n" "ld2 {v0.2s, v1.2s}, [%0] \n" @@ -242,8 +247,8 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "v0"); -#endif // __ARM_FEATURE_DOTPROD -#else +#endif +#else // __aarch64__ asm volatile( "pld [%0, #128] \n" "vld1.s8 {d0-d1}, [%0 :64] \n" @@ -253,7 +258,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b : "0"(img0), "1"(tmpptr) : "memory", "q0"); -#endif +#endif // __aarch64__ img0 += size * 8; } } @@ -266,7 +271,7 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) 
/ 2 + i % 2); #else signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); #endif @@ -307,29 +312,95 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b } } +#if __ARM_FEATURE_DOTPROD + int nn_outch = outch / 2; + int remain_outch_start = nn_outch * 2; + #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) + for (int pp = 0; pp < nn_outch; pp++) { + int p = pp * 2; + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); int i = 0; -#if __aarch64__ -#if __ARM_FEATURE_DOTPROD - for (; i + 15 < size; i += 16) + for (; i + 7 < size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16); - const signed char* kptr0 = kernel.channel(p); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 2); int nn = inch * maxk; // inch always > 0 +#if __ARM_FEATURE_MATMUL_INT8 asm volatile( - "ld1 {v24.16b}, [%3], #16 \n" // _w0123_l - "eor v0.16b, v0.16b, v0.16b \n" "eor v1.16b, v1.16b, v1.16b \n" + "eor v2.16b, v2.16b, v2.16b \n" + "eor v3.16b, v3.16b, v3.16b \n" + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + + "0: \n" - "ld1 {v16.16b}, [%2], #16 \n" // _val0123_l + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%3], #64 \n" // _val0 _val1 _val2 _val3 + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%4], #64 \n" // _w01 _w23 _w45 _w67 + + "smmla v0.4s, v16.16b, v20.16b \n" + "smmla v1.4s, v16.16b, v21.16b \n" + "smmla v2.4s, v17.16b, v20.16b \n" + "smmla v3.4s, v17.16b, v21.16b \n" + "smmla v4.4s, v18.16b, v20.16b \n" + "smmla v5.4s, v18.16b, v21.16b \n" + "smmla
v6.4s, v19.16b, v20.16b \n" + "smmla v7.4s, v19.16b, v21.16b \n" + + "subs %w2, %w2, #1 \n" + + "smmla v8.4s, v16.16b, v22.16b \n" + "smmla v9.4s, v16.16b, v23.16b \n" + "smmla v10.4s, v17.16b, v22.16b \n" + "smmla v11.4s, v17.16b, v23.16b \n" + "smmla v12.4s, v18.16b, v22.16b \n" + "smmla v13.4s, v18.16b, v23.16b \n" + "smmla v14.4s, v19.16b, v22.16b \n" + "smmla v15.4s, v19.16b, v23.16b \n" + "bne 0b \n" + + "st2 {v0.2d, v1.2d}, [%0], #32 \n" + "st2 {v2.2d, v3.2d}, [%0], #32 \n" + "st2 {v4.2d, v5.2d}, [%0], #32 \n" + "st2 {v6.2d, v7.2d}, [%0], #32 \n" + "st2 {v8.2d, v9.2d}, [%1], #32 \n" + "st2 {v10.2d, v11.2d}, [%1], #32 \n" + "st2 {v12.2d, v13.2d}, [%1], #32 \n" + "st2 {v14.2d, v15.2d}, [%1], #32 \n" + : "=r"(outptr0), + "=r"(outptr1), + "=r"(nn), + "=r"(tmpptr), + "=r"(kptr0) + : "0"(outptr0), + "1"(outptr1), + "2"(nn), + "3"(tmpptr), + "4"(kptr0) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else // __ARM_FEATURE_MATMUL_INT8 + asm volatile( + "eor v0.16b, v0.16b, v0.16b \n" + "eor v1.16b, v1.16b, v1.16b \n" "eor v2.16b, v2.16b, v2.16b \n" "eor v3.16b, v3.16b, v3.16b \n" "eor v4.16b, v4.16b, v4.16b \n" @@ -347,97 +418,405 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b "0: \n" - "ld1 {v17.16b}, [%2], #16 \n" // _val4567_l + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%3], #64 \n" // _val0 _val1 _val2 _val3 + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%4], #64 \n" // _w01 _w23 _w45 _w67 + + "sdot v0.4s, v20.16b, v16.4b[0] \n" + "sdot v1.4s, v20.16b, v16.4b[1] \n" + "sdot v2.4s, v20.16b, v16.4b[2] \n" + "sdot v3.4s, v20.16b, v16.4b[3] \n" + "sdot v4.4s, v20.16b, v17.4b[0] \n" + "sdot v5.4s, v20.16b, v17.4b[1] \n" + "sdot v6.4s, v20.16b, v17.4b[2] \n" + "sdot v7.4s, v20.16b, v17.4b[3] \n" + + "sdot v0.4s, v21.16b, v18.4b[0] \n" + "sdot v1.4s, v21.16b, v18.4b[1] \n" + "sdot v2.4s, v21.16b, v18.4b[2] 
\n" + "sdot v3.4s, v21.16b, v18.4b[3] \n" + "sdot v4.4s, v21.16b, v19.4b[0] \n" + "sdot v5.4s, v21.16b, v19.4b[1] \n" + "sdot v6.4s, v21.16b, v19.4b[2] \n" + "sdot v7.4s, v21.16b, v19.4b[3] \n" + + "subs %w2, %w2, #1 \n" + + "sdot v8.4s, v22.16b, v16.4b[0] \n" + "sdot v9.4s, v22.16b, v16.4b[1] \n" + "sdot v10.4s, v22.16b, v16.4b[2] \n" + "sdot v11.4s, v22.16b, v16.4b[3] \n" + "sdot v12.4s, v22.16b, v17.4b[0] \n" + "sdot v13.4s, v22.16b, v17.4b[1] \n" + "sdot v14.4s, v22.16b, v17.4b[2] \n" + "sdot v15.4s, v22.16b, v17.4b[3] \n" + + "sdot v8.4s, v23.16b, v18.4b[0] \n" + "sdot v9.4s, v23.16b, v18.4b[1] \n" + "sdot v10.4s, v23.16b, v18.4b[2] \n" + "sdot v11.4s, v23.16b, v18.4b[3] \n" + "sdot v12.4s, v23.16b, v19.4b[0] \n" + "sdot v13.4s, v23.16b, v19.4b[1] \n" + "sdot v14.4s, v23.16b, v19.4b[2] \n" + "sdot v15.4s, v23.16b, v19.4b[3] \n" + + "bne 0b \n" + + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" + : "=r"(outptr0), + "=r"(outptr1), + "=r"(nn), + "=r"(tmpptr), + "=r"(kptr0) + : "0"(outptr0), + "1"(outptr1), + "2"(nn), + "3"(tmpptr), + "4"(kptr0) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#endif // __ARM_FEATURE_MATMUL_INT8 + } + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 2); + + int nn = inch * maxk; // inch always > 0 + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + 
int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); - "sdot v0.4s, v24.16b, v16.4b[0] \n" - "sdot v1.4s, v24.16b, v16.4b[1] \n" - "sdot v2.4s, v24.16b, v16.4b[2] \n" - "sdot v3.4s, v24.16b, v16.4b[3] \n" + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); - "ld1 {v18.16b}, [%2], #16 \n" // _val891011_l + _sum4 = vmmlaq_s32(_sum4, _val0, _w45); + _sum5 = vmmlaq_s32(_sum5, _val0, _w67); + _sum6 = vmmlaq_s32(_sum6, _val1, _w45); + _sum7 = vmmlaq_s32(_sum7, _val1, _w67); - "sdot v4.4s, v24.16b, v17.4b[0] \n" - "sdot v5.4s, v24.16b, v17.4b[1] \n" - "sdot v6.4s, v24.16b, v17.4b[2] \n" - "sdot v7.4s, v24.16b, v17.4b[3] \n" + tmpptr += 32; + kptr0 += 64; + } - "ld1 {v19.16b}, [%2], #16 \n" // _val12131415_l + int64x2x2_t _sum01; + _sum01.val[0] = vreinterpretq_s64_s32(_sum0); + _sum01.val[1] = vreinterpretq_s64_s32(_sum1); - "sdot v8.4s, v24.16b, v18.4b[0] \n" - "sdot v9.4s, v24.16b, v18.4b[1] \n" + int64x2x2_t _sum23; + _sum23.val[0] = vreinterpretq_s64_s32(_sum2); + _sum23.val[1] = vreinterpretq_s64_s32(_sum3); - "ld1 {v25.16b}, [%3], #16 \n" // _w0123_h + int64x2x2_t _sum45; + _sum45.val[0] = vreinterpretq_s64_s32(_sum4); + _sum45.val[1] = vreinterpretq_s64_s32(_sum5); - "sdot v10.4s, v24.16b, v18.4b[2] \n" - "sdot v11.4s, v24.16b, v18.4b[3] \n" + int64x2x2_t _sum67; + _sum67.val[0] = vreinterpretq_s64_s32(_sum6); + _sum67.val[1] = vreinterpretq_s64_s32(_sum7); - "ld1 {v20.16b}, [%2], #16 \n" // _val0123_h + vst2q_s64((int64_t*)outptr0, _sum01); + vst2q_s64((int64_t*)(outptr0 + 8), _sum23); - "sdot v12.4s, v24.16b, v19.4b[0] \n" - "sdot v13.4s, v24.16b, v19.4b[1] \n" - "sdot v14.4s, v24.16b, v19.4b[2] \n" - "sdot v15.4s, v24.16b, v19.4b[3] \n" + vst2q_s64((int64_t*)outptr1, 
_sum45); + vst2q_s64((int64_t*)(outptr1 + 8), _sum67); - "ld1 {v21.16b}, [%2], #16 \n" // _val4567_h + outptr0 += 16; + outptr1 += 16; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); - "sdot v0.4s, v25.16b, v20.4b[0] \n" - "sdot v1.4s, v25.16b, v20.4b[1] \n" - "sdot v2.4s, v25.16b, v20.4b[2] \n" - "sdot v3.4s, v25.16b, v20.4b[3] \n" + for (int j = 0; j < nn; j++) + { + int8x16_t _val0123_l = vld1q_s8(tmpptr); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); - "ld1 {v22.16b}, [%2], #16 \n" // _val891011_h + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3); - "sdot v4.4s, v25.16b, v21.4b[0] \n" - "sdot v5.4s, v25.16b, v21.4b[1] \n" - "sdot v6.4s, v25.16b, v21.4b[2] \n" - "sdot v7.4s, v25.16b, v21.4b[3] \n" + _sum4 = vdotq_laneq_s32(_sum4, _w4567_l, _val0123_l, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_l, _val0123_l, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_l, _val0123_l, 2); + _sum7 = vdotq_laneq_s32(_sum7, _w4567_l, _val0123_l, 3); + _sum4 = vdotq_laneq_s32(_sum4, _w4567_h, _val0123_h, 0); + _sum5 = vdotq_laneq_s32(_sum5, _w4567_h, _val0123_h, 1); + _sum6 = vdotq_laneq_s32(_sum6, _w4567_h, _val0123_h, 2); + 
_sum7 = vdotq_laneq_s32(_sum7, _w4567_h, _val0123_h, 3); - "ld1 {v23.16b}, [%2], #16 \n" // _val12131415_h + tmpptr += 32; + kptr0 += 64; + } - "sdot v8.4s, v25.16b, v22.4b[0] \n" - "sdot v9.4s, v25.16b, v22.4b[1] \n" + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr0 + 4, _sum1); + vst1q_s32(outptr0 + 8, _sum2); + vst1q_s32(outptr0 + 12, _sum3); + vst1q_s32(outptr1, _sum4); + vst1q_s32(outptr1 + 4, _sum5); + vst1q_s32(outptr1 + 8, _sum6); + vst1q_s32(outptr1 + 12, _sum7); + outptr0 += 16; + outptr1 += 16; +#endif // __ARM_FEATURE_MATMUL_INT8 + } + for (; i + 1 < size; i += 2) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p / 2); - "ld1 {v24.16b}, [%3], #16 \n" // _w0123_l + int nn = inch * maxk; // inch always > 0 - "sdot v10.4s, v25.16b, v22.4b[2] \n" - "sdot v11.4s, v25.16b, v22.4b[3] \n" +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); - "ld1 {v16.16b}, [%2], #16 \n" // _val0123_l + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); - "sdot v12.4s, v25.16b, v23.4b[0] \n" - "sdot v13.4s, v25.16b, v23.4b[1] \n" + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + _sum2 = vmmlaq_s32(_sum2, _val, _w45); + _sum3 = vmmlaq_s32(_sum3, _val, _w67); - "subs %w1, %w1, #1 \n" + tmpptr += 16; + kptr0 += 64; + } - "sdot v14.4s, v25.16b, v23.4b[2] \n" - "sdot v15.4s, v25.16b, v23.4b[3] \n" + int64x2x2_t _sum01; + _sum01.val[0] = vreinterpretq_s64_s32(_sum0); + _sum01.val[1] = vreinterpretq_s64_s32(_sum1); - "bne 0b \n" + int64x2x2_t _sum23; + _sum23.val[0] = vreinterpretq_s64_s32(_sum2); + _sum23.val[1] = vreinterpretq_s64_s32(_sum3); - "sub %2, %2, #16 \n" - "sub %3, 
%3, #16 \n" + vst2q_s64((int64_t*)outptr0, _sum01); + vst2q_s64((int64_t*)outptr1, _sum23); - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n" - "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0], #64 \n" - : "=r"(outptr0), - "=r"(nn), - "=r"(tmpptr), - "=r"(kptr0) - : "0"(outptr0), - "1"(nn), - "2"(tmpptr), - "3"(kptr0) - : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + outptr0 += 8; + outptr1 += 8; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val01_l_h = vld1q_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val01_l_h, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val01_l_h, 1); + _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val01_l_h, 2); + _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val01_l_h, 3); + + _sum2 = vdotq_laneq_s32(_sum2, _w4567_l, _val01_l_h, 0); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_l, _val01_l_h, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w4567_h, _val01_l_h, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w4567_h, _val01_l_h, 3); + + tmpptr += 16; + kptr0 += 64; + } + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr0 + 4, _sum1); + vst1q_s32(outptr1, _sum2); + vst1q_s32(outptr1 + 4, _sum3); + outptr0 += 8; + outptr1 += 8; +#endif // __ARM_FEATURE_MATMUL_INT8 } + for (; i < size; i++) + { + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 
2); + + int nn = inch * maxk; // inch always > 0 + +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + int8x16_t _w45 = vld1q_s8(kptr0 + 32); + int8x16_t _w67 = vld1q_s8(kptr0 + 48); + + int8x16_t _val = vcombine_s8(_val0, _val0); + + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + _sum45 = vdotq_s32(_sum45, _val, _w45); + _sum67 = vdotq_s32(_sum67, _val, _w67); + + tmpptr += 8; + kptr0 += 64; + } + + int32x4_t _s0123 = vpaddq_s32(_sum01, _sum23); + int32x4_t _s4567 = vpaddq_s32(_sum45, _sum67); + + vst1q_s32(outptr0, _s0123); + vst1q_s32(outptr1, _s4567); + outptr0 += 4; + outptr1 += 4; +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + + int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + int8x16_t _w4567_l = vld1q_s8(kptr0 + 32); + int8x16_t _w4567_h = vld1q_s8(kptr0 + 48); + + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); + _sum1 = vdotq_lane_s32(_sum1, _w4567_l, _val0_l_h, 0); + _sum1 = vdotq_lane_s32(_sum1, _w4567_h, _val0_l_h, 1); + + tmpptr += 8; + kptr0 += 64; + } + + vst1q_s32(outptr0, _sum0); + vst1q_s32(outptr1, _sum1); + outptr0 += 4; + outptr1 += 4; +#endif // __ARM_FEATURE_MATMUL_INT8 + } + } +#else // __ARM_FEATURE_DOTPROD + int remain_outch_start = 0; +#endif // __ARM_FEATURE_DOTPROD + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = remain_outch_start; p < outch; p++) + { + int* outptr0 = top_blob.channel(p); + + int i = 0; +#if __aarch64__ +#if __ARM_FEATURE_DOTPROD for (; i + 7 < 
size; i += 8) { - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8); - const signed char* kptr0 = kernel.channel(p); + const signed char* tmpptr = tmp.channel(i / 8); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); int nn = inch * maxk; // inch always > 0 +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _val2 = vld1q_s8(tmpptr + 32); + int8x16_t _val3 = vld1q_s8(tmpptr + 48); + + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + _sum4 = vmmlaq_s32(_sum4, _val2, _w01); + _sum5 = vmmlaq_s32(_sum5, _val2, _w23); + _sum6 = vmmlaq_s32(_sum6, _val3, _w01); + _sum7 = vmmlaq_s32(_sum7, _val3, _w23); + + tmpptr += 64; + kptr0 += 32; + } + + int64x2x2_t _sum01; + _sum01.val[0] = vreinterpretq_s64_s32(_sum0); + _sum01.val[1] = vreinterpretq_s64_s32(_sum1); + + int64x2x2_t _sum23; + _sum23.val[0] = vreinterpretq_s64_s32(_sum2); + _sum23.val[1] = vreinterpretq_s64_s32(_sum3); + + int64x2x2_t _sum45; + _sum45.val[0] = vreinterpretq_s64_s32(_sum4); + _sum45.val[1] = vreinterpretq_s64_s32(_sum5); + + int64x2x2_t _sum67; + _sum67.val[0] = vreinterpretq_s64_s32(_sum6); + _sum67.val[1] = vreinterpretq_s64_s32(_sum7); + + vst2q_s64((int64_t*)outptr0, _sum01); + vst2q_s64((int64_t*)(outptr0 + 8), _sum23); + vst2q_s64((int64_t*)(outptr0 + 16), _sum45); + vst2q_s64((int64_t*)(outptr0 + 24), _sum67); + + outptr0 += 32; +#else // __ARM_FEATURE_MATMUL_INT8 int32x4_t _sum0 = 
vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -451,8 +830,11 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { int8x16_t _val0123_l = vld1q_s8(tmpptr); int8x16_t _val4567_l = vld1q_s8(tmpptr + 16); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 32); + int8x16_t _val4567_h = vld1q_s8(tmpptr + 48); int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); @@ -463,11 +845,6 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b _sum6 = vdotq_laneq_s32(_sum6, _w0123_l, _val4567_l, 2); _sum7 = vdotq_laneq_s32(_sum7, _w0123_l, _val4567_l, 3); - int8x16_t _val0123_h = vld1q_s8(tmpptr + 32); - int8x16_t _val4567_h = vld1q_s8(tmpptr + 48); - - int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); - _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1); _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); @@ -490,20 +867,56 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b vst1q_s32(outptr0 + 24, _sum6); vst1q_s32(outptr0 + 28, _sum7); outptr0 += 32; +#endif // __ARM_FEATURE_MATMUL_INT8 } -#endif +#endif // __ARM_FEATURE_DOTPROD for (; i + 3 < size; i += 4) { #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + + for (int j = 
0; j < nn; j++) + { + int8x16_t _val0 = vld1q_s8(tmpptr); + int8x16_t _val1 = vld1q_s8(tmpptr + 16); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val0, _w01); + _sum1 = vmmlaq_s32(_sum1, _val0, _w23); + _sum2 = vmmlaq_s32(_sum2, _val1, _w01); + _sum3 = vmmlaq_s32(_sum3, _val1, _w23); + + tmpptr += 32; + kptr0 += 32; + } + + int64x2x2_t _sum01; + _sum01.val[0] = vreinterpretq_s64_s32(_sum0); + _sum01.val[1] = vreinterpretq_s64_s32(_sum1); + + int64x2x2_t _sum23; + _sum23.val[0] = vreinterpretq_s64_s32(_sum2); + _sum23.val[1] = vreinterpretq_s64_s32(_sum3); + + vst2q_s64((int64_t*)outptr0, _sum01); + vst2q_s64((int64_t*)(outptr0 + 8), _sum23); + + outptr0 += 16; +#elif __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); int32x4_t _sum2 = vdupq_n_s32(0); @@ -512,16 +925,15 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b for (int j = 0; j < nn; j++) { int8x16_t _val0123_l = vld1q_s8(tmpptr); + int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0); _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1); _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2); _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3); - int8x16_t _val0123_h = vld1q_s8(tmpptr + 16); - int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); - _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0); _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1); _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2); @@ -737,19 +1149,45 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2); + const signed 
char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); + const signed char* kptr0 = kernel.channel(p); #endif #else const signed char* tmpptr = tmp.channel(i / 2); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = inch * maxk; // inch always > 0 #if __aarch64__ -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x16_t _val = vld1q_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); + + _sum0 = vmmlaq_s32(_sum0, _val, _w01); + _sum1 = vmmlaq_s32(_sum1, _val, _w23); + + tmpptr += 16; + kptr0 += 32; + } + + int64x2x2_t _sum01; + _sum01.val[0] = vreinterpretq_s64_s32(_sum0); + _sum01.val[1] = vreinterpretq_s64_s32(_sum1); + + vst2q_s64((int64_t*)outptr0, _sum01); + + outptr0 += 8; +#elif __ARM_FEATURE_DOTPROD int32x4_t _sum0 = vdupq_n_s32(0); int32x4_t _sum1 = vdupq_n_s32(0); @@ -757,12 +1195,11 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { int8x16_t _val01_l_h = vld1q_s8(tmpptr); int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val01_l_h, 0); _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val01_l_h, 1); - int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); - _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val01_l_h, 2); _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val01_l_h, 3); @@ -1007,30 +1444,52 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b { #if __aarch64__ #if __ARM_FEATURE_DOTPROD - const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2); + const signed char* kptr0 = kernel.channel(p / 2 + p % 2); #else const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); + 
const signed char* kptr0 = kernel.channel(p); #endif #else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); -#endif const signed char* kptr0 = kernel.channel(p); +#endif int nn = inch * maxk; // inch always > 0 -#if __ARM_FEATURE_DOTPROD - int32x4_t _sum0 = vdupq_n_s32(0); +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); for (int j = 0; j < nn; j++) { - int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x8_t _val0 = vld1_s8(tmpptr); + int8x16_t _w01 = vld1q_s8(kptr0); + int8x16_t _w23 = vld1q_s8(kptr0 + 16); - int8x16_t _w0123_l = vld1q_s8(kptr0); + int8x16_t _val = vcombine_s8(_val0, _val0); - _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); + _sum01 = vdotq_s32(_sum01, _val, _w01); + _sum23 = vdotq_s32(_sum23, _val, _w23); + tmpptr += 8; + kptr0 += 32; + } + + int32x4_t _s0123 = vpaddq_s32(_sum01, _sum23); + + vst1q_s32(outptr0, _s0123); + outptr0 += 4; +#elif __ARM_FEATURE_DOTPROD + int32x4_t _sum0 = vdupq_n_s32(0); + + for (int j = 0; j < nn; j++) + { + int8x8_t _val0_l_h = vld1_s8(tmpptr); + int8x16_t _w0123_l = vld1q_s8(kptr0); int8x16_t _w0123_h = vld1q_s8(kptr0 + 16); + _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0); _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1); tmpptr += 8; @@ -1118,12 +1577,22 @@ static void im2col_sgemm_pack8to4_int8_neon(const Mat& bottom_im2col, Mat& top_b static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __ARM_NEON && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon_i8mm(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && 
NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD if (ncnn::cpu_support_arm_asimddp()) { convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon_asimddp(_kernel, kernel_tm, inch, outch, kernel_w, kernel_h); return; } +#endif #endif const int maxk = kernel_w * kernel_h; @@ -1131,27 +1600,45 @@ static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(const M // interleave // src = maxk-inch-outch // dst = 8a-4b-maxk-inch/8a-outch/4b - // dst = 4a-4b-2-maxk-inch/8a-outch/4b (arm82) + // dst = 4a-4b-2aa-2bb-maxk-inch/8a-outch/8b (arm82) + // dst = 8a-8b-maxk-inch/8a-outch/8b (arm84) Mat kernel = _kernel.reshape(maxk, inch, outch); +#if __ARM_FEATURE_DOTPROD + if (outch >= 8) + kernel_tm.create(64 * maxk, inch / 8, outch / 8 + (outch % 8) / 4, (size_t)1u); + else + kernel_tm.create(32 * maxk, inch / 8, outch / 4, (size_t)1u); +#else kernel_tm.create(32 * maxk, inch / 8, outch / 4, (size_t)1u); +#endif - for (int q = 0; q + 3 < outch; q += 4) + int q = 0; +#if __ARM_FEATURE_DOTPROD + for (; q + 7 < outch; q += 8) { - signed char* g00 = kernel_tm.channel(q / 4); + signed char* g00 = kernel_tm.channel(q / 8); for (int p = 0; p + 7 < inch; p += 8) { for (int k = 0; k < maxk; k++) { -#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#else // __ARM_FEATURE_MATMUL_INT8 for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00++; } } @@ -1160,9 +1647,71 @@ static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(const M for (int j = 4; j < 8; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = 
k00[k]; + g00++; + } + } + for (int i = 4; i < 8; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#endif // __ARM_FEATURE_MATMUL_INT8 + } + } + } +#endif // __ARM_FEATURE_DOTPROD + for (; q + 3 < outch; q += 4) + { +#if __ARM_FEATURE_DOTPROD + signed char* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else + signed char* g00 = kernel_tm.channel(q / 4); +#endif + for (int p = 0; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { +#if __ARM_FEATURE_MATMUL_INT8 + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } +#elif __ARM_FEATURE_DOTPROD + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; + g00++; + } + } + for (int i = 0; i < 4; i++) + { + for (int j = 4; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + g00[0] = k00[k]; g00++; } } @@ -1172,9 +1721,7 @@ static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_neon(const M for (int j = 0; j < 8; j++) { const signed char* k00 = kernel.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00++; } } diff --git a/tests/test_convolution.cpp b/tests/test_convolution.cpp index 0568fed5e..5f7f5d209 100644 --- a/tests/test_convolution.cpp +++ b/tests/test_convolution.cpp @@ -413,7 +413,8 @@ static int test_convolution_1() || test_convolution_int8(4, 8, 16, 24, 3, 1, 1, 1, 1) || test_convolution_int8(4, 20, 16, 24, 3, 1, 1, 1, 0) || test_convolution_int8(6, 7, 64, 64, 3, 1, 2, 0, 1) - || test_convolution_int8(25, 33, 16, 15, 3, 1, 1, 1, 0); + || test_convolution_int8(25, 33, 16, 15, 3, 1, 1, 1, 0) + || test_convolution_int8(7, 7, 15, 12, 3, 1, 1, 1, 0); } #endif // NCNN_INT8