From e9b8f0a6ef5d93fa5958da6533372efe8fc125a8 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 11 Jan 2022 15:50:00 +0800 Subject: [PATCH] x86 avx2 optimization for convolution gemm int8 (#3489) --- src/layer/x86/convolution_1x1_int8.h | 70 + src/layer/x86/convolution_1x1_pack1to4_int8.h | 83 + src/layer/x86/convolution_1x1_pack8to1_int8.h | 65 + src/layer/x86/convolution_1x1_pack8to4_int8.h | 2 +- src/layer/x86/convolution_3x3_pack1to4_int8.h | 147 ++ src/layer/x86/convolution_7x7_pack1to4_int8.h | 80 + src/layer/x86/convolution_sgemm_int8.h | 1994 ++++++++--------- .../x86/convolution_sgemm_pack1to4_int8.h | 750 +++++++ .../x86/convolution_sgemm_pack8to1_int8.h | 839 +++++++ .../x86/convolution_sgemm_pack8to4_int8.h | 231 ++ src/layer/x86/convolution_x86.cpp | 319 +-- src/layer/x86/convolution_x86.h | 5 +- src/layer/x86/x86_usability.h | 26 +- 13 files changed, 3361 insertions(+), 1250 deletions(-) create mode 100644 src/layer/x86/convolution_1x1_pack1to4_int8.h create mode 100644 src/layer/x86/convolution_1x1_pack8to1_int8.h create mode 100644 src/layer/x86/convolution_3x3_pack1to4_int8.h create mode 100644 src/layer/x86/convolution_7x7_pack1to4_int8.h create mode 100644 src/layer/x86/convolution_sgemm_pack1to4_int8.h create mode 100644 src/layer/x86/convolution_sgemm_pack8to1_int8.h diff --git a/src/layer/x86/convolution_1x1_int8.h b/src/layer/x86/convolution_1x1_int8.h index d35bf282b..e03639b75 100644 --- a/src/layer/x86/convolution_1x1_int8.h +++ b/src/layer/x86/convolution_1x1_int8.h @@ -12,6 +12,76 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
+static void conv1x1s1_sgemm_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + const int size = w * h; + + Mat bottom_im2col = bottom_blob; + bottom_im2col.w = size; + bottom_im2col.h = 1; + + im2col_sgemm_int8_sse(bottom_im2col, top_blob, kernel, opt); +} + +static void conv1x1s2_sgemm_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int tailstep = w - 2 * outw + w; + + Mat bottom_blob_shrinked; + bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const signed char* r0 = bottom_blob.channel(p); + signed char* outptr = bottom_blob_shrinked.channel(p); + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + outptr[0] = r0[0]; + outptr[1] = r0[2]; + outptr[2] = r0[4]; + outptr[3] = r0[6]; + + r0 += 8; + outptr += 4; + } + for (; j + 1 < outw; j += 2) + { + outptr[0] = r0[0]; + outptr[1] = r0[2]; + + r0 += 4; + outptr += 2; + } + for (; j < outw; j++) + { + outptr[0] = r0[0]; + + r0 += 2; + outptr += 1; + } + + r0 += tailstep; + } + } + + conv1x1s1_sgemm_int8_sse(bottom_blob_shrinked, top_blob, kernel, opt); +} + static void conv1x1s1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) { int inch = bottom_blob.c; diff --git a/src/layer/x86/convolution_1x1_pack1to4_int8.h b/src/layer/x86/convolution_1x1_pack1to4_int8.h new file mode 100644 index 000000000..452a948c8 --- /dev/null +++ b/src/layer/x86/convolution_1x1_pack1to4_int8.h @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making ncnn available. 
+// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv1x1s1_sgemm_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + const int size = w * h; + + Mat bottom_im2col = bottom_blob; + bottom_im2col.w = size; + bottom_im2col.h = 1; + + im2col_sgemm_pack1to4_int8_sse(bottom_im2col, top_blob, kernel, opt); +} + +static void conv1x1s2_sgemm_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int tailstep = w - 2 * outw + w; + + Mat bottom_blob_shrinked; + bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const signed char* r0 = bottom_blob.channel(p); + signed char* outptr = bottom_blob_shrinked.channel(p); + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + outptr[0] = r0[0]; + outptr[1] = r0[2]; + outptr[2] = r0[4]; + outptr[3] = r0[6]; + + r0 += 8; + outptr += 4; + } + for (; j + 1 < outw; j += 2) + { + outptr[0] = r0[0]; + outptr[1] = r0[2]; + + r0 += 4; + 
outptr += 2; + } + for (; j < outw; j++) + { + outptr[0] = r0[0]; + + r0 += 2; + outptr += 1; + } + + r0 += tailstep; + } + } + + conv1x1s1_sgemm_pack1to4_int8_sse(bottom_blob_shrinked, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_1x1_pack8to1_int8.h b/src/layer/x86/convolution_1x1_pack8to1_int8.h new file mode 100644 index 000000000..f8f70525b --- /dev/null +++ b/src/layer/x86/convolution_1x1_pack8to1_int8.h @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void conv1x1s1_sgemm_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + const int size = w * h; + + Mat bottom_im2col = bottom_blob; + bottom_im2col.w = size; + bottom_im2col.h = 1; + + im2col_sgemm_pack8to1_int8_sse(bottom_im2col, top_blob, kernel, opt); +} + +static void conv1x1s2_sgemm_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int tailstep = w - 2 * outw + w; + + Mat bottom_blob_shrinked; + bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const int64_t* r0 = bottom_blob.channel(p); + int64_t* outptr = bottom_blob_shrinked.channel(p); + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j < outw; j++) + { + outptr[0] = r0[0]; + + r0 += 2; + outptr += 1; + } + + r0 += tailstep; + } + } + + conv1x1s1_sgemm_pack8to1_int8_sse(bottom_blob_shrinked, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_1x1_pack8to4_int8.h b/src/layer/x86/convolution_1x1_pack8to4_int8.h index 2efeabc14..9c1982341 100644 --- a/src/layer/x86/convolution_1x1_pack8to4_int8.h +++ b/src/layer/x86/convolution_1x1_pack8to4_int8.h @@ -25,7 +25,7 @@ static void conv1x1s1_sgemm_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_b im2col_sgemm_pack8to4_int8_sse(bottom_im2col, top_blob, kernel, opt); } -static void conv1x1s2_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +static void conv1x1s2_sgemm_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) { int w = bottom_blob.w; int channels 
= bottom_blob.c; diff --git a/src/layer/x86/convolution_3x3_pack1to4_int8.h b/src/layer/x86/convolution_3x3_pack1to4_int8.h new file mode 100644 index 000000000..2ec4eae71 --- /dev/null +++ b/src/layer/x86/convolution_3x3_pack1to4_int8.h @@ -0,0 +1,147 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = 9; + + // im2col + Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator); + { + const int gap = w - outw; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + signed char* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < 3; u++) + { + for (int v = 0; v < 3; v++) + { + const signed char* sptr = img.row(u) + v; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[1]; + ptr[2] = sptr[2]; + ptr[3] = sptr[3]; + + sptr += 4; + ptr += 4; + } + for (; j + 1 < outw; j += 2) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[1]; + + sptr += 2; + ptr += 2; + } + for (; j < 
outw; j++) + { + ptr[0] = sptr[0]; + + sptr += 1; + ptr += 1; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack1to4_int8_sse(bottom_im2col, top_blob, kernel, opt); +} + +static void conv3x3s2_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = 9; + + // im2col + Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator); + { + const int gap = w * 2 - outw * 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + signed char* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < 3; u++) + { + for (int v = 0; v < 3; v++) + { + const signed char* sptr = img.row(u) + v; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[2]; + ptr[2] = sptr[4]; + ptr[3] = sptr[6]; + + sptr += 8; + ptr += 4; + } + for (; j + 1 < outw; j += 2) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[2]; + + sptr += 4; + ptr += 2; + } + for (; j < outw; j++) + { + ptr[0] = sptr[0]; + + sptr += 2; + ptr += 1; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack1to4_int8_sse(bottom_im2col, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_7x7_pack1to4_int8.h b/src/layer/x86/convolution_7x7_pack1to4_int8.h new file mode 100644 index 000000000..7fa67fed0 --- /dev/null +++ b/src/layer/x86/convolution_7x7_pack1to4_int8.h @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv7x7s2_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = 49; + + // im2col + Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator); + { + const int gap = w * 2 - outw * 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + signed char* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < 7; u++) + { + for (int v = 0; v < 7; v++) + { + const signed char* sptr = img.row(u) + v; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[2]; + ptr[2] = sptr[4]; + ptr[3] = sptr[6]; + + sptr += 8; + ptr += 4; + } + for (; j + 1 < outw; j += 2) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[2]; + + sptr += 4; + ptr += 2; + } + for (; j < outw; j++) + { + ptr[0] = sptr[0]; + + sptr += 2; + ptr += 1; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack1to4_int8_sse(bottom_im2col, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_sgemm_int8.h b/src/layer/x86/convolution_sgemm_int8.h index f1d1025e1..767340fb4 100644 --- a/src/layer/x86/convolution_sgemm_int8.h +++ b/src/layer/x86/convolution_sgemm_int8.h @@ -1,6 +1,6 @@ -// BUG1989 is pleased to support the open source community by supporting ncnn available. 
+// Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2019 BUG1989. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -12,1341 +12,1111 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv_im2col_sgemm_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, - const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt) +static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) { - int w = bottom_blob.w; - int inch = bottom_blob.c; + // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; + const int size = bottom_im2col.w; + const int maxk = bottom_im2col.h; + const int inch = bottom_im2col.c; - const signed char* kernel = _kernel; + const int outch = top_blob.c; - // im2row - Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator); + // permute + Mat tmp; +#if __SSE2__ + if (inch >= 4) { - signed char* ret = (signed char*)bottom_im2row; - int retID = 0; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - for (int p = 0; p < inch; p++) - { - const signed char* input = bottom_blob.channel(p); - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - int row = u + i * stride_h; - int col = v + j * stride_w; - int index = row * w + col; - ret[retID] = input[index]; - retID++; - } - } - } - } - } +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + 
size % 2, 4u, 4, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator); + else + tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator); +#else + if (size >= 2) + tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator); + else + tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator); +#endif + } + else + { +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator); +#else + if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator); +#endif } - - int kernel_size = kernel_w * kernel_h; - int out_size = outw * outh; - - // int M = outch; // outch - int N = outw * outh; // outsize or out stride - int K = kernel_w * kernel_h * inch; // ksize * inch - - // bottom_im2row memory packed 4 x 4 - Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator); { - int nn_size = out_size >> 2; - int remain_size_start = nn_size << 2; +#if __AVX2__ + int remain_size_start = 0; + int nn_size = size >> 2; #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { - int i = ii * 4; - - const signed char* img0 = bottom_im2row.row(i); - const signed char* img1 = bottom_im2row.row(i + 1); - const signed char* img2 = bottom_im2row.row(i + 2); - const signed char* img3 = bottom_im2row.row(i + 3); + int i = remain_size_start + ii * 4; - signed char* tmpptr = bottom_tm.channel(i / 4); + signed char* tmpptr = tmp.channel(i / 4); int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) - { - 
tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; - tmpptr[2] = img1[0]; - tmpptr[3] = img1[1]; - tmpptr[4] = img2[0]; - tmpptr[5] = img2[1]; - tmpptr[6] = img3[0]; - tmpptr[7] = img3[1]; - - tmpptr += 8; - img0 += 2; - img1 += 2; - img2 += 2; - img3 += 2; - } - - for (; q < inch * kernel_size; q++) + for (; q + 3 < inch; q += 4) { - tmpptr[0] = img0[0]; - tmpptr[1] = img1[0]; - tmpptr[2] = img2[0]; - tmpptr[3] = img3[0]; + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - tmpptr += 4; - img0 += 1; - img1 += 1; - img2 += 1; - img3 += 1; + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img0[1]; + tmpptr[5] = img1[1]; + tmpptr[6] = img2[1]; + tmpptr[7] = img3[1]; + tmpptr[8] = img0[2]; + tmpptr[9] = img1[2]; + tmpptr[10] = img2[2]; + tmpptr[11] = img3[2]; + tmpptr[12] = img0[3]; + tmpptr[13] = img1[3]; + tmpptr[14] = img2[3]; + tmpptr[15] = img3[3]; + tmpptr += 16; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; + } } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_size_start; i < out_size; i++) - { - const signed char* img0 = bottom_im2row.row(i); - - signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4); - - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) + for (; q < inch; q++) { - tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - tmpptr += 2; - img0 += 2; - } + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img0[1]; + tmpptr[2] = img0[2]; + tmpptr[3] = img0[3]; - for (; q < inch * kernel_size; q++) - { - tmpptr[0] = img0[0]; + tmpptr += 4; - 
tmpptr += 1; - img0 += 1; + img0 += size; + } } } - } - // kernel memory packed 4 x 4 - Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator); - { - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; + remain_size_start += nn_size << 2; + nn_size = (size - remain_size_start) >> 1; +#else + int remain_size_start = 0; + int nn_size = (size - remain_size_start) >> 1; +#endif #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) + for (int ii = 0; ii < nn_size; ii++) { - int p = pp * 4; + int i = remain_size_start + ii * 2; - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; - const signed char* k1 = kernel + (p + 1) * inch * kernel_size; - const signed char* k2 = kernel + (p + 2) * inch * kernel_size; - const signed char* k3 = kernel + (p + 3) * inch * kernel_size; - - signed char* ktmp = kernel_tm.channel(p / 4); +#if __AVX2__ + signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + signed char* tmpptr = tmp.channel(i / 2); +#endif int q = 0; - for (; q + 1 < inch * kernel_size; q += 2) + for (; q + 3 < inch; q += 4) { - ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp[2] = k1[0]; - ktmp[3] = k1[1]; - ktmp[4] = k2[0]; - ktmp[5] = k2[1]; - ktmp[6] = k3[0]; - ktmp[7] = k3[1]; - - ktmp += 8; - - k0 += 2; - k1 += 2; - k2 += 2; - k3 += 2; - } + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - for (; q < inch * kernel_size; q++) - { - ktmp[0] = k0[0]; - ktmp[1] = k1[0]; - ktmp[2] = k2[0]; - ktmp[3] = k3[0]; - ktmp += 4; - - k0 += 1; - k1 += 1; - k2 += 1; - k3 += 1; + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = 
img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img0[1]; + tmpptr[5] = img1[1]; + tmpptr[6] = img2[1]; + tmpptr[7] = img3[1]; + tmpptr += 8; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; + } } - } + for (; q < inch; q++) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img0[1]; - signed char* ktmp = kernel_tm.channel(p / 4 + p % 4); + tmpptr += 2; - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) - { - ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp += 2; - k0 += 2; - } - - for (; q < inch * kernel_size; q++) - { - ktmp[0] = k0[0]; - ktmp++; - k0++; + img0 += size; + } } } - } - - // 4x4 - // sgemm(int M, int N, int K, float* A, float* B, float* C) - { - // int M = outch; // outch - // int N = outw * outh; // outsize or out stride - // int L = kernel_w * kernel_h * inch; // ksize * inch - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; + remain_size_start += nn_size << 1; #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) + for (int i = remain_size_start; i < size; i++) { - int i = pp * 4; - - int* output0 = top_blob.channel(i); - int* output1 = top_blob.channel(i + 1); - int* output2 = top_blob.channel(i + 2); - int* output3 = top_blob.channel(i + 3); +#if __AVX2__ + signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif - int j = 0; - for (; j + 3 < N; j = j + 4) + int q = 0; + for (; q + 3 < inch; q += 4) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4); - - int sum0[4] = {0}; - int sum1[4] = {0}; - int sum2[4] = 
{0}; - int sum3[4] = {0}; - - int k = 0; - - for (; k + 1 < K; k = k + 2) - { - for (int n = 0; n < 4; n++) - { - sum0[n] += (int)va[0] * vb[2 * n]; // k0 - sum0[n] += (int)va[1] * vb[2 * n + 1]; - - sum1[n] += (int)va[2] * vb[2 * n]; // k1 - sum1[n] += (int)va[3] * vb[2 * n + 1]; - - sum2[n] += (int)va[4] * vb[2 * n]; // k2 - sum2[n] += (int)va[5] * vb[2 * n + 1]; - - sum3[n] += (int)va[6] * vb[2 * n]; // k3 - sum3[n] += (int)va[7] * vb[2 * n + 1]; - } - - va += 8; - vb += 8; - } + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; - for (; k < K; k++) + for (int k = 0; k < maxk; k++) { - for (int n = 0; n < 4; n++) - { - sum0[n] += (int)va[0] * vb[n]; - sum1[n] += (int)va[1] * vb[n]; - sum2[n] += (int)va[2] * vb[n]; - sum3[n] += (int)va[3] * vb[n]; - } - - va += 4; - vb += 4; + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr += 4; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; } - - for (int n = 0; n < 4; n++) - { - output0[n] = sum0[n]; - output1[n] = sum1[n]; - output2[n] = sum2[n]; - output3[n] = sum3[n]; - } - output0 += 4; - output1 += 4; - output2 += 4; - output3 += 4; } - - for (; j < N; j++) + for (; q < inch; q++) { - int sum0 = 0; - int sum1 = 0; - int sum2 = 0; - int sum3 = 0; + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4); - - int k = 0; - - for (; k + 1 < K; k = k + 2) + for (int k = 0; k < maxk; k++) { - sum0 += (int)va[0] * vb[0]; - sum0 += (int)va[1] * vb[1]; - - sum1 += (int)va[2] * vb[0]; - sum1 += (int)va[3] * vb[1]; + tmpptr[0] = img0[0]; - sum2 += (int)va[4] * vb[0]; - sum2 += 
(int)va[5] * vb[1]; + tmpptr += 1; - sum3 += (int)va[6] * vb[0]; - sum3 += (int)va[7] * vb[1]; - - va += 8; - vb += 2; - } - - for (; k < K; k++) - { - sum0 += (int)va[0] * vb[0]; - sum1 += (int)va[1] * vb[0]; - sum2 += (int)va[2] * vb[0]; - sum3 += (int)va[3] * vb[0]; - - va += 4; - vb += 1; + img0 += size; } - - output0[0] = sum0; - output1[0] = sum1; - output2[0] = sum2; - output3[0] = sum3; - - output0++; - output1++; - output2++; - output3++; } } - + } +#else // __SSE2__ + tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator); + { #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_outch_start; i < outch; i++) + for (int i = 0; i < size; i++) { - int* output = top_blob.channel(i); + signed char* tmpptr = tmp.channel(i); - int j = 0; - for (; j + 3 < N; j = j + 4) + int q = 0; + for (; q < inch; q++) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); - int sum[4] = {0}; + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; - int k = 0; - for (; k + 1 < K; k = k + 2) + for (int k = 0; k < maxk; k++) { - for (int n = 0; n < 4; n++) - { - sum[n] += (int)va[0] * vb[2 * n]; - sum[n] += (int)va[1] * vb[2 * n + 1]; - } - va += 2; - vb += 8; - } + tmpptr[0] = img0[0]; - for (; k < K; k++) - { - for (int n = 0; n < 4; n++) - { - sum[n] += (int)va[0] * vb[n]; - } - va += 1; - vb += 4; - } + tmpptr += 1; - for (int n = 0; n < 4; n++) - { - output[n] = sum[n]; + img0 += size; } - output += 4; - } - - for (; j < N; j++) - { - int sum = 0; - - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); - - for (int k = 0; k < K; k++) - { - sum += (int)va[0] * vb[0]; - - va += 1; - vb += 1; - } - output[0] = sum; - - output++; } } } +#endif // __SSE2__ - // // sgemm(int M, int N, int K, float* A, float* B, float* C) - // { - // for (int i=0; i scale_dequant, const Option& opt) -{ - int w = bottom_blob.w; - int inch = 
bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; + int nn_outch = 0; + int remain_outch_start = 0; - const signed char* kernel = _kernel; - const float* bias = _bias; +#if __SSE2__ + nn_outch = outch >> 2; - // im2row - Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator); + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outch; pp++) { - signed char* ret = (signed char*)bottom_im2row; - int retID = 0; + int p = pp * 4; - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - for (int p = 0; p < inch; p++) - { - const signed char* input = bottom_blob.channel(p); - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - int row = u + i * stride_h; - int col = v + j * stride_w; - int index = row * w + col; - ret[retID] = input[index]; - retID++; - } - } - } - } - } - } - - int kernel_size = kernel_w * kernel_h; - int out_size = outw * outh; - - // int M = outch; // outch - int N = outw * outh; // outsize or out stride - int K = kernel_w * kernel_h * inch; // ksize * inch - - // bottom_im2row memory packed 4 x 4 - Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator); - { - int nn_size = out_size >> 2; - int remain_size_start = nn_size << 2; + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); + int* outptr2 = top_blob.channel(p + 2); + int* outptr3 = top_blob.channel(p + 3); - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) + int i = 0; +#if __AVX2__ + for (; i + 3 < size; i += 4) { - int i = ii * 4; + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p / 4); - const signed char* img0 = bottom_im2row.row(i); - const signed char* img1 = bottom_im2row.row(i + 1); - const signed char* img2 = bottom_im2row.row(i + 2); - const signed char* img3 = 
bottom_im2row.row(i + 3); + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; - signed char* tmpptr = bottom_tm.channel(i / 4); + __m256i _sum00_12 = _mm256_setzero_si256(); + __m256i _sum20_32 = _mm256_setzero_si256(); - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) + if (nn4 > 0) { - tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; - tmpptr[2] = img1[0]; - tmpptr[3] = img1[1]; - tmpptr[4] = img2[0]; - tmpptr[5] = img2[1]; - tmpptr[6] = img3[0]; - tmpptr[7] = img3[1]; - - tmpptr += 8; - img0 += 2; - img1 += 2; - img2 += 2; - img3 += 2; - } - - for (; q < inch * kernel_size; q++) - { - tmpptr[0] = img0[0]; - tmpptr[1] = img1[0]; - tmpptr[2] = img2[0]; - tmpptr[3] = img3[0]; - - tmpptr += 4; - img0 += 1; - img1 += 1; - img2 += 1; - img3 += 1; - } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_size_start; i < out_size; i++) - { - const signed char* img0 = bottom_im2row.row(i); - - signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4); + __m256i _sum10_02 = _mm256_setzero_si256(); + __m256i _sum01_13 = _mm256_setzero_si256(); + __m256i _sum11_03 = _mm256_setzero_si256(); + __m256i _sum30_22 = _mm256_setzero_si256(); + __m256i _sum21_33 = _mm256_setzero_si256(); + __m256i _sum31_23 = _mm256_setzero_si256(); + + int j = 0; + for (; j < nn4; j++) + { + __m128i _val0123 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val0123_16 = _mm256_cvtepi8_epi16(_val0123); + + __m256i _val01_16 = _mm256_permute4x64_epi64(_val0123_16, _MM_SHUFFLE(1, 1, 0, 0)); + __m256i _val23_16 = _mm256_permute4x64_epi64(_val0123_16, _MM_SHUFFLE(3, 3, 2, 2)); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + __m256i _val32_16 = _mm256_permute4x64_epi64(_val23_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = 
_mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl20_31 = _mm256_mullo_epi16(_val23_16, _w01_16); + __m256i _sh20_31 = _mm256_mulhi_epi16(_val23_16, _w01_16); + __m256i _sl30_21 = _mm256_mullo_epi16(_val32_16, _w01_16); + __m256i _sh30_21 = _mm256_mulhi_epi16(_val32_16, _w01_16); + + _sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_unpacklo_epi16(_sl20_31, _sh20_31)); + _sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_unpacklo_epi16(_sl30_21, _sh30_21)); + _sum21_33 = _mm256_add_epi32(_sum21_33, _mm256_unpackhi_epi16(_sl20_31, _sh20_31)); + _sum31_23 = _mm256_add_epi32(_sum31_23, _mm256_unpackhi_epi16(_sl30_21, _sh30_21)); + + tmpptr += 16; + kptr0 += 16; + } - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) - { - tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02); + _tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02); + _tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum20_32, _sum30_22); + _tmp1 = _mm256_unpacklo_epi32(_sum21_33, _sum31_23); + _tmp2 = _mm256_unpackhi_epi32(_sum20_32, _sum30_22); + _tmp3 = _mm256_unpackhi_epi32(_sum21_33, _sum31_23); + _sum20_32 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum30_22 = 
_mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum21_33 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum31_23 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } - tmpptr += 2; - img0 += 2; - } + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02); + _sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13); - for (; q < inch * kernel_size; q++) - { - tmpptr[0] = img0[0]; + _sum20_32 = _mm256_add_epi32(_sum20_32, _sum30_22); + _sum21_33 = _mm256_add_epi32(_sum21_33, _sum31_23); + _sum20_32 = _mm256_add_epi32(_sum20_32, _sum21_33); - tmpptr += 1; - img0 += 1; + __m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0); + _sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask); + _sum20_32 = _mm256_permutevar8x32_epi32(_sum20_32, _perm_mask); } - } - } - - // kernel memory packed 4 x 4 - Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator); - { - int nn_outch = 0; - int remain_outch_start = 0; - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; + __m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0); + __m128i _sum10 = _mm256_extracti128_si256(_sum00_12, 1); + __m128i _sum20 = _mm256_extracti128_si256(_sum20_32, 0); + __m128i _sum30 = _mm256_extracti128_si256(_sum20_32, 1); - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; - const signed char* k1 = kernel + (p + 1) * inch * kernel_size; - const signed char* k2 = kernel + (p + 2) * inch * kernel_size; - const signed char* k3 = kernel + (p + 3) * inch * kernel_size; - - signed char* ktmp = kernel_tm.channel(p / 4); - - int q = 0; - for (; q + 1 < inch * kernel_size; q += 2) + int j = 0; + for (; j < nn1; j++) { - ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp[2] = k1[0]; - ktmp[3] = k1[1]; - ktmp[4] = k2[0]; - ktmp[5] = k2[1]; - ktmp[6] = k3[0]; - ktmp[7] = k3[1]; - - ktmp += 8; - - k0 += 2; - k1 
+= 2; - k2 += 2; - k3 += 2; - } + __m128i _val01 = _mm_set_epi16(tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[0], tmpptr[0], tmpptr[0], tmpptr[0]); + __m128i _val23 = _mm_set_epi16(tmpptr[3], tmpptr[3], tmpptr[3], tmpptr[3], tmpptr[2], tmpptr[2], tmpptr[2], tmpptr[2]); - for (; q < inch * kernel_size; q++) - { - ktmp[0] = k0[0]; - ktmp[1] = k1[0]; - ktmp[2] = k2[0]; - ktmp[3] = k3[0]; - ktmp += 4; - - k0 += 1; - k1 += 1; - k2 += 1; - k3 += 1; - } - } + __m128i _w0123 = _mm_set_epi16(kptr0[3], kptr0[2], kptr0[1], kptr0[0], kptr0[3], kptr0[2], kptr0[1], kptr0[0]); - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; + __m128i _sl00 = _mm_mullo_epi16(_val01, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val01, _w0123); + __m128i _sl10 = _mm_mullo_epi16(_val23, _w0123); + __m128i _sh10 = _mm_mulhi_epi16(_val23, _w0123); - signed char* ktmp = kernel_tm.channel(p / 4 + p % 4); + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum20 = _mm_add_epi32(_sum20, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum30 = _mm_add_epi32(_sum30, _mm_unpackhi_epi16(_sl10, _sh10)); - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) - { - ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp += 2; - k0 += 2; + tmpptr += 4; + kptr0 += 4; } - for (; q < inch * kernel_size; q++) - { - ktmp[0] = k0[0]; - ktmp++; - k0++; - } + int sum[16]; + _mm_storeu_si128((__m128i*)sum, _sum00); + _mm_storeu_si128((__m128i*)(sum + 4), _sum10); + _mm_storeu_si128((__m128i*)(sum + 8), _sum20); + _mm_storeu_si128((__m128i*)(sum + 12), _sum30); + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0[1] = sum[4]; + outptr1[1] = sum[5]; + outptr2[1] = sum[6]; + outptr3[1] = sum[7]; + outptr0[2] = sum[8]; + outptr1[2] = sum[9]; + outptr2[2] = sum[10]; + outptr3[2] 
= sum[11]; + outptr0[3] = sum[12]; + outptr1[3] = sum[13]; + outptr2[3] = sum[14]; + outptr3[3] = sum[15]; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; } - } - - // 4x4 - // sgemm(int M, int N, int K, float* A, float* B, float* C) - { - // int M = outch; // outch - // int N = outw * outh; // outsize or out stride - // int L = kernel_w * kernel_h * inch; // ksize * inch - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) +#endif + for (; i + 1 < size; i += 2) { - int i = pp * 4; - - const float bias0 = bias ? bias[i] : 0.f; - const float bias1 = bias ? bias[i + 1] : 0.f; - const float bias2 = bias ? bias[i + 2] : 0.f; - const float bias3 = bias ? bias[i + 3] : 0.f; - - const float scale_dequant0 = scale_dequant[i]; - const float scale_dequant1 = scale_dequant[i + 1]; - const float scale_dequant2 = scale_dequant[i + 2]; - const float scale_dequant3 = scale_dequant[i + 3]; - - float* output0 = top_blob.channel(i); - float* output1 = top_blob.channel(i + 1); - float* output2 = top_blob.channel(i + 2); - float* output3 = top_blob.channel(i + 3); - - int j = 0; - for (; j + 3 < N; j = j + 4) +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + const signed char* tmpptr = tmp.channel(i / 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4); + + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; + +#if __AVX2__ + __m256i _sum00_12 = _mm256_setzero_si256(); +#else + __m128i _sum00 = _mm_setzero_si128(); + __m128i _sum10 = _mm_setzero_si128(); +#endif + + if (nn4 > 0) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4); - - int sum0[4] = {0}; - int sum1[4] = {0}; - int sum2[4] = {0}; - int sum3[4] = {0}; - - int k = 0; - - for (; k + 1 < K; k = k + 2) +#if __AVX2__ + __m256i _sum10_02 = 
_mm256_setzero_si256(); + __m256i _sum01_13 = _mm256_setzero_si256(); + __m256i _sum11_03 = _mm256_setzero_si256(); +#else + __m128i _sum01 = _mm_setzero_si128(); + __m128i _sum02 = _mm_setzero_si128(); + __m128i _sum03 = _mm_setzero_si128(); + __m128i _sum11 = _mm_setzero_si128(); + __m128i _sum12 = _mm_setzero_si128(); + __m128i _sum13 = _mm_setzero_si128(); +#endif + + int j = 0; + for (; j < nn4; j++) { - for (int n = 0; n < 4; n++) - { - sum0[n] += (int)va[0] * vb[2 * n]; // k0 - sum0[n] += (int)va[1] * vb[2 * n + 1]; - - sum1[n] += (int)va[2] * vb[2 * n]; // k1 - sum1[n] += (int)va[3] * vb[2 * n + 1]; - - sum2[n] += (int)va[4] * vb[2 * n]; // k2 - sum2[n] += (int)va[5] * vb[2 * n + 1]; - - sum3[n] += (int)va[6] * vb[2 * n]; // k3 - sum3[n] += (int)va[7] * vb[2 * n + 1]; - } - - va += 8; - vb += 8; +#if __AVX2__ + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + _val01_16 = _mm256_permute4x64_epi64(_val01_16, _MM_SHUFFLE(1, 1, 0, 0)); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + + _sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); +#else + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + _val01 = _mm_unpacklo_epi8(_val01, 
_extval01); + + __m128i _val0 = _mm_shuffle_epi32(_val01, _MM_SHUFFLE(1, 0, 1, 0)); + __m128i _val1 = _mm_shuffle_epi32(_val01, _MM_SHUFFLE(3, 2, 3, 2)); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl01 = _mm_mullo_epi16(_val0, _w1); + __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1); + __m128i _sl10 = _mm_mullo_epi16(_val1, _w0); + __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0); + __m128i _sl11 = _mm_mullo_epi16(_val1, _w1); + __m128i _sh11 = _mm_mulhi_epi16(_val1, _w1); + + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl01, _sh01)); + _sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl01, _sh01)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl10, _sh10)); + _sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl11, _sh11)); + _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl11, _sh11)); +#endif + + tmpptr += 8; + kptr0 += 16; } - for (; k < K; k++) +#if __AVX2__ + // transpose 4x8 { - for (int n = 0; n < 4; n++) - { - sum0[n] += (int)va[0] * vb[n]; - sum1[n] += (int)va[1] * vb[n]; - sum2[n] += (int)va[2] * vb[n]; - sum3[n] += (int)va[3] * vb[n]; - } - - va += 4; - vb += 4; + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02); + _tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02); + _tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + 
_sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); } - for (int n = 0; n < 4; n++) + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02); + _sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13); + + __m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0); + _sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask); +#else + // transpose 4x4 { - output0[n] = (float)sum0[n] * scale_dequant0 + bias0; - output1[n] = (float)sum1[n] * scale_dequant1 + bias1; - output2[n] = (float)sum2[n] * scale_dequant2 + bias2; - output3[n] = (float)sum3[n] * scale_dequant3 + bias3; + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); + _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); + _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); + _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); + _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); } - output0 += 4; - output1 += 4; - output2 += 4; - output3 += 4; - } - - for (; j < N; j++) - { - int sum0 = 0; - int sum1 = 0; - int sum2 = 0; - int sum3 = 0; - - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4); - - int k = 0; - - for (; k + 1 < K; k = k + 2) { - sum0 += (int)va[0] * vb[0]; - sum0 += (int)va[1] * vb[1]; + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); + _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); + _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); + _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); + _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } - sum1 += (int)va[2] * vb[0]; - sum1 += (int)va[3] * vb[1]; + _sum00 = _mm_add_epi32(_sum00, _sum01); + _sum02 = _mm_add_epi32(_sum02, 
_sum03); + _sum10 = _mm_add_epi32(_sum10, _sum11); + _sum12 = _mm_add_epi32(_sum12, _sum13); - sum2 += (int)va[4] * vb[0]; - sum2 += (int)va[5] * vb[1]; + _sum00 = _mm_add_epi32(_sum00, _sum02); + _sum10 = _mm_add_epi32(_sum10, _sum12); +#endif + } - sum3 += (int)va[6] * vb[0]; - sum3 += (int)va[7] * vb[1]; +#if __AVX2__ + __m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0); + __m128i _sum10 = _mm256_extracti128_si256(_sum00_12, 1); +#endif - va += 8; - vb += 2; - } + int j = 0; + for (; j < nn1; j++) + { + __m128i _val = _mm_set_epi16(tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[0], tmpptr[0], tmpptr[0], tmpptr[0]); - for (; k < K; k++) - { - sum0 += (int)va[0] * vb[0]; - sum1 += (int)va[1] * vb[0]; - sum2 += (int)va[2] * vb[0]; - sum3 += (int)va[3] * vb[0]; + __m128i _w0123 = _mm_set_epi16(kptr0[3], kptr0[2], kptr0[1], kptr0[0], kptr0[3], kptr0[2], kptr0[1], kptr0[0]); - va += 4; - vb += 1; - } + __m128i _sl00 = _mm_mullo_epi16(_val, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val, _w0123); - output0[0] = (float)sum0 * scale_dequant0 + bias0; - output1[0] = (float)sum1 * scale_dequant1 + bias1; - output2[0] = (float)sum2 * scale_dequant2 + bias2; - output3[0] = (float)sum3 * scale_dequant3 + bias3; + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl00, _sh00)); - output0++; - output1++; - output2++; - output3++; + tmpptr += 2; + kptr0 += 4; } - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_outch_start; i < outch; i++) + int sum[8]; + _mm_storeu_si128((__m128i*)sum, _sum00); + _mm_storeu_si128((__m128i*)(sum + 4), _sum10); + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0[1] = sum[4]; + outptr1[1] = sum[5]; + outptr2[1] = sum[6]; + outptr3[1] = sum[7]; + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; + } + for (; i < size; i++) { - float* output = top_blob.channel(i); +#if __AVX2__ 
+ const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + const signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4); - const float bias0 = bias ? bias[i] : 0.f; - const float scale_dequant0 = scale_dequant[i]; + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; - int j = 0; - for (; j + 3 < N; j = j + 4) + __m128i _sum0 = _mm_setzero_si128(); + + if (nn4 > 0) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); - int sum[4] = {0}; + __m128i _sum1 = _mm_setzero_si128(); + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); - int k = 0; - for (; k + 1 < K; k = k + 2) + int j = 0; + for (; j < nn4; j++) { - for (int n = 0; n < 4; n++) - { - sum[n] += (int)va[0] * vb[2 * n]; - sum[n] += (int)va[1] * vb[2 * n + 1]; - } - va += 2; - vb += 8; + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01); + + _val0 = _mm_shuffle_epi32(_val0, _MM_SHUFFLE(1, 0, 1, 0)); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl01 = _mm_mullo_epi16(_val0, _w1); + __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl01, _sh01)); + _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl01, _sh01)); + + tmpptr += 4; + kptr0 += 16; } - for (; k < K; k++) + // transpose 4x4 { - for 
(int n = 0; n < 4; n++) - { - sum[n] += (int)va[0] * vb[n]; - } - va += 1; - vb += 4; + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); + _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); + _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); } - for (int n = 0; n < 4; n++) - { - output[n] = (float)sum[n] * scale_dequant0 + bias0; - } - output += 4; + _sum0 = _mm_add_epi32(_sum0, _sum1); + _sum2 = _mm_add_epi32(_sum2, _sum3); + _sum0 = _mm_add_epi32(_sum0, _sum2); } - for (; j < N; j++) + int j = 0; + for (; j < nn1; j++) { - int sum = 0; + __m128i _val = _mm_set1_epi16(tmpptr[0]); - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); + __m128i _w0123 = _mm_set_epi16(0, 0, 0, 0, kptr0[3], kptr0[2], kptr0[1], kptr0[0]); - for (int k = 0; k < K; k++) - { - sum += (int)va[0] * vb[0]; + __m128i _sl00 = _mm_mullo_epi16(_val, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val, _w0123); - va += 1; - vb += 1; - } - output[0] = (float)sum * scale_dequant0 + bias0; + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); - output++; + tmpptr += 1; + kptr0 += 4; } + + int sum[4]; + _mm_storeu_si128((__m128i*)sum, _sum0); + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; } } - // // sgemm(int M, int N, int K, float* A, float* B, float* C) - // { - // for (int i=0; i scale_requant, const Option& opt) -{ - int w = bottom_blob.w; - int inch = bottom_blob.c; + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = remain_outch_start; p < outch; p++) + { + int* outptr0 = top_blob.channel(p); - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; + int i 
= 0; +#if __SSE2__ +#if __AVX2__ + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); - const signed char* kernel = _kernel; - const float* bias = _bias; + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; - // im2row - Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator); - { - signed char* ret = (signed char*)bottom_im2row; - int retID = 0; + int sum0 = 0; + int sum1 = 0; + int sum2 = 0; + int sum3 = 0; - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) + if (nn4 > 0) { - for (int p = 0; p < inch; p++) + int j = 0; + for (; j < nn4; j++) { - const signed char* input = bottom_blob.channel(p); - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - int row = u + i * stride_h; - int col = v + j * stride_w; - int index = row * w + col; - ret[retID] = input[index]; - retID++; - } - } + signed char val0 = tmpptr[0]; + signed char val1 = tmpptr[1]; + signed char val2 = tmpptr[2]; + signed char val3 = tmpptr[3]; + signed char val4 = tmpptr[4]; + signed char val5 = tmpptr[5]; + signed char val6 = tmpptr[6]; + signed char val7 = tmpptr[7]; + signed char val8 = tmpptr[8]; + signed char val9 = tmpptr[9]; + signed char val10 = tmpptr[10]; + signed char val11 = tmpptr[11]; + signed char val12 = tmpptr[12]; + signed char val13 = tmpptr[13]; + signed char val14 = tmpptr[14]; + signed char val15 = tmpptr[15]; + + signed char w0 = kptr0[0]; + signed char w1 = kptr0[1]; + signed char w2 = kptr0[2]; + signed char w3 = kptr0[3]; + + sum0 += val0 * w0; + sum0 += val1 * w1; + sum0 += val2 * w2; + sum0 += val3 * w3; + sum1 += val4 * w0; + sum1 += val5 * w1; + sum1 += val6 * w2; + sum1 += val7 * w3; + sum2 += val8 * w0; + sum2 += val9 * w1; + sum2 += val10 * w2; + sum2 += val11 * w3; + sum3 += val12 * w0; + sum3 += val13 * w1; + sum3 += val14 * w2; + sum3 += val15 * w3; + + tmpptr += 16; + kptr0 += 4; } } 
- } - } - int kernel_size = kernel_w * kernel_h; - int out_size = outw * outh; + int j = 0; + for (; j < nn1; j++) + { + signed char val0 = tmpptr[0]; + signed char val1 = tmpptr[1]; + signed char val2 = tmpptr[2]; + signed char val3 = tmpptr[3]; + signed char w = kptr0[0]; - // int M = outch; // outch - int N = outw * outh; // outsize or out stride - int K = kernel_w * kernel_h * inch; // ksize * inch + sum0 += val0 * w; + sum1 += val1 * w; + sum2 += val2 * w; + sum3 += val3 * w; - // bottom_im2row memory packed 4 x 4 - Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator); - { - int nn_size = out_size >> 2; - int remain_size_start = nn_size << 2; + tmpptr += 4; + kptr0 += 1; + } - #pragma omp parallel for num_threads(opt.num_threads) - for (int ii = 0; ii < nn_size; ii++) + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0[2] = sum2; + outptr0[3] = sum3; + outptr0 += 4; + } +#endif + for (; i + 1 < size; i += 2) { - int i = ii * 4; +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + const signed char* tmpptr = tmp.channel(i / 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); - const signed char* img0 = bottom_im2row.row(i); - const signed char* img1 = bottom_im2row.row(i + 1); - const signed char* img2 = bottom_im2row.row(i + 2); - const signed char* img3 = bottom_im2row.row(i + 3); + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; - signed char* tmpptr = bottom_tm.channel(i / 4); + int sum0 = 0; + int sum1 = 0; - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) + if (nn4 > 0) { - tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; - tmpptr[2] = img1[0]; - tmpptr[3] = img1[1]; - tmpptr[4] = img2[0]; - tmpptr[5] = img2[1]; - tmpptr[6] = img3[0]; - tmpptr[7] = img3[1]; - - tmpptr += 8; - img0 += 2; - img1 += 2; - img2 += 2; - img3 += 2; + int j = 0; + for (; j < nn4; j++) + { + signed char val0 = tmpptr[0]; + signed char val1 = tmpptr[1]; + signed 
char val2 = tmpptr[2]; + signed char val3 = tmpptr[3]; + signed char val4 = tmpptr[4]; + signed char val5 = tmpptr[5]; + signed char val6 = tmpptr[6]; + signed char val7 = tmpptr[7]; + + signed char w0 = kptr0[0]; + signed char w1 = kptr0[1]; + signed char w2 = kptr0[2]; + signed char w3 = kptr0[3]; + + sum0 += val0 * w0; + sum0 += val1 * w1; + sum0 += val2 * w2; + sum0 += val3 * w3; + sum1 += val4 * w0; + sum1 += val5 * w1; + sum1 += val6 * w2; + sum1 += val7 * w3; + + tmpptr += 8; + kptr0 += 4; + } } - for (; q < inch * kernel_size; q++) + int j = 0; + for (; j < nn1; j++) { - tmpptr[0] = img0[0]; - tmpptr[1] = img1[0]; - tmpptr[2] = img2[0]; - tmpptr[3] = img3[0]; + signed char val0 = tmpptr[0]; + signed char val1 = tmpptr[1]; + signed char w = kptr0[0]; - tmpptr += 4; - img0 += 1; - img1 += 1; - img2 += 1; - img3 += 1; - } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_size_start; i < out_size; i++) - { - const signed char* img0 = bottom_im2row.row(i); - - signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4); - - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) - { - tmpptr[0] = img0[0]; - tmpptr[1] = img0[1]; + sum0 += val0 * w; + sum1 += val1 * w; tmpptr += 2; - img0 += 2; + kptr0 += 1; } - for (; q < inch * kernel_size; q++) - { - tmpptr[0] = img0[0]; - - tmpptr += 1; - img0 += 1; - } + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; } - } - - // kernel memory packed 4 x 4 - Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator); - { - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) + for (; i < size; i++) { - int p = pp * 4; +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + const signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + const signed 
char* kptr0 = kernel.channel(p / 4 + p % 4); - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; - const signed char* k1 = kernel + (p + 1) * inch * kernel_size; - const signed char* k2 = kernel + (p + 2) * inch * kernel_size; - const signed char* k3 = kernel + (p + 3) * inch * kernel_size; + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; - signed char* ktmp = kernel_tm.channel(p / 4); + int sum = 0; - int q = 0; - for (; q + 1 < inch * kernel_size; q += 2) + if (nn4 > 0) { - ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp[2] = k1[0]; - ktmp[3] = k1[1]; - ktmp[4] = k2[0]; - ktmp[5] = k2[1]; - ktmp[6] = k3[0]; - ktmp[7] = k3[1]; - - ktmp += 8; - - k0 += 2; - k1 += 2; - k2 += 2; - k3 += 2; + int j = 0; + for (; j < nn4; j++) + { + signed char val0 = tmpptr[0]; + signed char val1 = tmpptr[1]; + signed char val2 = tmpptr[2]; + signed char val3 = tmpptr[3]; + + signed char w0 = kptr0[0]; + signed char w1 = kptr0[1]; + signed char w2 = kptr0[2]; + signed char w3 = kptr0[3]; + + sum += val0 * w0; + sum += val1 * w1; + sum += val2 * w2; + sum += val3 * w3; + + tmpptr += 4; + kptr0 += 4; + } } - for (; q < inch * kernel_size; q++) + int j = 0; + for (; j < nn1; j++) { - ktmp[0] = k0[0]; - ktmp[1] = k1[0]; - ktmp[2] = k2[0]; - ktmp[3] = k3[0]; - ktmp += 4; - - k0 += 1; - k1 += 1; - k2 += 1; - k3 += 1; + signed char val = tmpptr[0]; + signed char w = kptr0[0]; + + sum += val * w; + + tmpptr += 1; + kptr0 += 1; } - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) + outptr0[0] = sum; + outptr0 += 1; + } +#else // __SSE2__ + for (; i < size; i++) { - const signed char* k0 = kernel + (p + 0) * inch * kernel_size; + const signed char* tmpptr = tmp.channel(i); + const signed char* kptr0 = kernel.channel(p); - signed char* ktmp = kernel_tm.channel(p / 4 + p % 4); + int nn1 = inch * maxk; - int q = 0; - for (; q + 1 < inch * kernel_size; q = q + 2) + int sum = 0; + int j = 0; + for (; j < nn1; j++) { - 
ktmp[0] = k0[0]; - ktmp[1] = k0[1]; - ktmp += 2; - k0 += 2; - } + signed char val = tmpptr[0]; + signed char w = kptr0[0]; - for (; q < inch * kernel_size; q++) - { - ktmp[0] = k0[0]; - ktmp++; - k0++; + sum += val * w; + + tmpptr += 1; + kptr0 += 1; } + + outptr0[0] = sum; + outptr0 += 1; } +#endif // __SSE2__ } +} - // 4x4 - // sgemm(int M, int N, int K, float* A, float* B, float* C) +static void convolution_im2col_sgemm_transform_kernel_int8_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + const int maxk = kernel_w * kernel_h; + +#if __SSE2__ + // interleave + // src = maxk-inch-outch + // dst = 4a-4b-maxk-inch/4a-outch/4b + Mat kernel = _kernel.reshape(maxk, inch, outch); + if (outch >= 4) { - // int M = outch; // outch - // int N = outw * outh; // outsize or out stride - // int L = kernel_w * kernel_h * inch; // ksize * inch - - int nn_outch = 0; - int remain_outch_start = 0; + if (inch >= 4) + kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4 + outch % 4, (size_t)1u); + else + kernel_tm.create(4 * maxk, inch, outch / 4 + outch % 4, (size_t)1u); + } + else + { + if (inch >= 4) + kernel_tm.create(4 * maxk, inch / 4 + inch % 4, outch, (size_t)1u); + else + kernel_tm.create(1 * maxk, inch, outch, (size_t)1u); + } - nn_outch = outch >> 2; - remain_outch_start = nn_outch << 2; + int q = 0; + for (; q + 3 < outch; q += 4) + { + signed char* g00 = kernel_tm.channel(q / 4); - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) + int p = 0; + for (; p + 3 < inch; p += 4) { - int i = pp * 4; - - signed char* output0 = top_blob.channel(i); - signed char* output1 = top_blob.channel(i + 1); - signed char* output2 = top_blob.channel(i + 2); - signed char* output3 = top_blob.channel(i + 3); - - const float bias0 = bias ? bias[i] : 0.f; - const float bias1 = bias ? bias[i + 1] : 0.f; - const float bias2 = bias ? bias[i + 2] : 0.f; - const float bias3 = bias ? 
bias[i + 3] : 0.f; - - const float scale_requant_in0 = scale_requant[2 * i]; - const float scale_requant_out0 = scale_requant[2 * i + 1]; - const float scale_requant_in1 = scale_requant[2 * (i + 1)]; - const float scale_requant_out1 = scale_requant[2 * (i + 1) + 1]; - const float scale_requant_in2 = scale_requant[2 * (i + 2)]; - const float scale_requant_out2 = scale_requant[2 * (i + 2) + 1]; - const float scale_requant_in3 = scale_requant[2 * (i + 3)]; - const float scale_requant_out3 = scale_requant[2 * (i + 3) + 1]; - - int j = 0; - for (; j + 3 < N; j = j + 4) + for (int k = 0; k < maxk; k++) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4); - - int sum0[4] = {0}; - int sum1[4] = {0}; - int sum2[4] = {0}; - int sum3[4] = {0}; - - int k = 0; - - for (; k + 1 < K; k = k + 2) + for (int i = 0; i < 4; i++) { - for (int n = 0; n < 4; n++) + for (int j = 0; j < 4; j++) { - sum0[n] += (int)va[0] * vb[2 * n]; // k0 - sum0[n] += (int)va[1] * vb[2 * n + 1]; + const signed char* k00 = kernel.channel(q + i).row(p + j); - sum1[n] += (int)va[2] * vb[2 * n]; // k1 - sum1[n] += (int)va[3] * vb[2 * n + 1]; + g00[0] = k00[k]; - sum2[n] += (int)va[4] * vb[2 * n]; // k2 - sum2[n] += (int)va[5] * vb[2 * n + 1]; - - sum3[n] += (int)va[6] * vb[2 * n]; // k3 - sum3[n] += (int)va[7] * vb[2 * n + 1]; + g00++; } - - va += 8; - vb += 8; } - - for (; k < K; k++) + } + } + for (; p < inch; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) { - for (int n = 0; n < 4; n++) - { - sum0[n] += (int)va[0] * vb[n]; - sum1[n] += (int)va[1] * vb[n]; - sum2[n] += (int)va[2] * vb[n]; - sum3[n] += (int)va[3] * vb[n]; - } + const signed char* k00 = kernel.channel(q + i).row(p); - va += 4; - vb += 4; - } + g00[0] = k00[k]; - for (int n = 0; n < 4; n++) - { - output0[n] = float2int8(((float)sum0[n] * scale_requant_in0 + bias0) * scale_requant_out0); - output1[n] = float2int8(((float)sum1[n] * scale_requant_in1 + bias1) * 
scale_requant_out1); - output2[n] = float2int8(((float)sum2[n] * scale_requant_in2 + bias2) * scale_requant_out2); - output3[n] = float2int8(((float)sum3[n] * scale_requant_in3 + bias3) * scale_requant_out3); + g00++; } - output0 += 4; - output1 += 4; - output2 += 4; - output3 += 4; } + } + } + // TODO unroll 2 + for (; q < outch; q++) + { + signed char* g00 = kernel_tm.channel(q / 4 + q % 4); - for (; j < N; j++) + int p = 0; + for (; p + 3 < inch; p += 4) + { + for (int k = 0; k < maxk; k++) { - int sum0 = 0; - int sum1 = 0; - int sum2 = 0; - int sum3 = 0; - - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4); - - int k = 0; - - for (; k + 1 < K; k = k + 2) + for (int j = 0; j < 4; j++) { - sum0 += (int)va[0] * vb[0]; - sum0 += (int)va[1] * vb[1]; + const signed char* k00 = kernel.channel(q).row(p + j); - sum1 += (int)va[2] * vb[0]; - sum1 += (int)va[3] * vb[1]; + g00[0] = k00[k]; - sum2 += (int)va[4] * vb[0]; - sum2 += (int)va[5] * vb[1]; + g00++; + } + } + } + for (; p < inch; p++) + { + for (int k = 0; k < maxk; k++) + { + const signed char* k00 = kernel.channel(q).row(p); - sum3 += (int)va[6] * vb[0]; - sum3 += (int)va[7] * vb[1]; + g00[0] = k00[k]; - va += 8; - vb += 2; - } + g00++; + } + } + } +#else // __SSE2__ + kernel_tm = _kernel.reshape(maxk, inch, outch); +#endif // __SSE2__ +} - for (; k < K; k++) - { - sum0 += (int)va[0] * vb[0]; - sum1 += (int)va[1] * vb[0]; - sum2 += (int)va[2] * vb[0]; - sum3 += (int)va[3] * vb[0]; +static void convolution_im2col_sgemm_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; - va += 4; - vb += 1; - } + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; - output0[0] = float2int8(((float)sum0 * scale_requant_in0 + bias0) * scale_requant_out0); - output1[0] = 
float2int8(((float)sum1 * scale_requant_in1 + bias1) * scale_requant_out1); - output2[0] = float2int8(((float)sum2 * scale_requant_in2 + bias2) * scale_requant_out2); - output3[0] = float2int8(((float)sum3 * scale_requant_in3 + bias3) * scale_requant_out3); + const int maxk = kernel_w * kernel_h; - output0++; - output1++; - output2++; - output3++; - } - } + // im2col + Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator); + { + const int gap = w * stride_h - outw * stride_w; #pragma omp parallel for num_threads(opt.num_threads) - for (int i = remain_outch_start; i < outch; i++) + for (int p = 0; p < inch; p++) { - signed char* output = top_blob.channel(i); - - const float bias0 = bias ? bias[i] : 0.f; - - const float scale_requant_in0 = scale_requant[2 * i]; - const float scale_requant_out0 = scale_requant[2 * i + 1]; + const Mat img = bottom_blob.channel(p); + signed char* ptr = bottom_im2col.channel(p); - int j = 0; - for (; j + 3 < N; j = j + 4) + for (int u = 0; u < kernel_h; u++) { - signed char* vb = bottom_tm.channel(j / 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); - int sum[4] = {0}; - - int k = 0; - for (; k + 1 < K; k = k + 2) + for (int v = 0; v < kernel_w; v++) { - for (int n = 0; n < 4; n++) - { - sum[n] += (int)va[0] * vb[2 * n]; - sum[n] += (int)va[1] * vb[2 * n + 1]; - } - va += 2; - vb += 8; - } + const signed char* sptr = img.row(dilation_h * u) + dilation_w * v; - for (; k < K; k++) - { - for (int n = 0; n < 4; n++) + for (int i = 0; i < outh; i++) { - sum[n] += (int)va[0] * vb[n]; - } - va += 1; - vb += 4; - } - - for (int n = 0; n < 4; n++) - { - output[n] = float2int8(((float)sum[n] * scale_requant_in0 + bias0) * scale_requant_out0); - } - output += 4; - } + int j = 0; + for (; j + 3 < outw; j += 4) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[stride_w]; + ptr[2] = sptr[stride_w * 2]; + ptr[3] = sptr[stride_w * 3]; - for (; j < N; j++) - { - int sum = 0; + sptr += stride_w * 4; + ptr += 4; + } + for (; j + 1 < outw; j += 
2) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[stride_w]; - signed char* vb = bottom_tm.channel(j / 4 + j % 4); - signed char* va = kernel_tm.channel(i / 4 + i % 4); + sptr += stride_w * 2; + ptr += 2; + } + for (; j < outw; j++) + { + ptr[0] = sptr[0]; - for (int k = 0; k < K; k++) - { - sum += (int)va[0] * vb[0]; + sptr += stride_w; + ptr += 1; + } - va += 1; - vb += 1; + sptr += gap; + } } - output[0] = float2int8(((float)sum * scale_requant_in0 + bias0) * scale_requant_out0); - - output++; } } } - // // sgemm(int M, int N, int K, float* A, float* B, float* C) - // { - // for (int i=0; i= 4) + { +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator); + else + tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator); +#else + if (size >= 2) + tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator); + else + tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator); +#endif + } + else + { +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator); +#else + if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator); +#endif + } + { +#if __AVX2__ + int remain_size_start = 0; + int nn_size = size >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + + signed char* tmpptr = tmp.channel(i / 4); + + int q = 0; + for (; q + 3 < inch; q += 4) + { + const 
signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img0[1]; + tmpptr[5] = img1[1]; + tmpptr[6] = img2[1]; + tmpptr[7] = img3[1]; + tmpptr[8] = img0[2]; + tmpptr[9] = img1[2]; + tmpptr[10] = img2[2]; + tmpptr[11] = img3[2]; + tmpptr[12] = img0[3]; + tmpptr[13] = img1[3]; + tmpptr[14] = img2[3]; + tmpptr[15] = img3[3]; + tmpptr += 16; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; + } + } + for (; q < inch; q++) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img0[1]; + tmpptr[2] = img0[2]; + tmpptr[3] = img0[3]; + + tmpptr += 4; + + img0 += size; + } + } + } + + remain_size_start += nn_size << 2; + nn_size = (size - remain_size_start) >> 1; +#else + int remain_size_start = 0; + int nn_size = (size - remain_size_start) >> 1; +#endif + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 2; + +#if __AVX2__ + signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + signed char* tmpptr = tmp.channel(i / 2); +#endif + + int q = 0; + for (; q + 3 < inch; q += 4) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] 
= img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr[4] = img0[1]; + tmpptr[5] = img1[1]; + tmpptr[6] = img2[1]; + tmpptr[7] = img3[1]; + tmpptr += 8; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; + } + } + for (; q < inch; q++) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img0[1]; + + tmpptr += 2; + + img0 += size; + } + } + } + + remain_size_start += nn_size << 1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) + { +#if __AVX2__ + signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + + int q = 0; + for (; q + 3 < inch; q += 4) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i; + const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i; + const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr[1] = img1[0]; + tmpptr[2] = img2[0]; + tmpptr[3] = img3[0]; + tmpptr += 4; + + img0 += size; + img1 += size; + img2 += size; + img3 += size; + } + } + for (; q < inch; q++) + { + const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + + tmpptr += 1; + + img0 += size; + } + } + } + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + int* outptr0 = top_blob.channel(p); + + int i = 0; +#if __AVX2__ + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p / 4); + + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; + + __m256i _sum00_12 = _mm256_setzero_si256(); + 
__m256i _sum20_32 = _mm256_setzero_si256(); + + if (nn4 > 0) + { + __m256i _sum10_02 = _mm256_setzero_si256(); + __m256i _sum01_13 = _mm256_setzero_si256(); + __m256i _sum11_03 = _mm256_setzero_si256(); + __m256i _sum30_22 = _mm256_setzero_si256(); + __m256i _sum21_33 = _mm256_setzero_si256(); + __m256i _sum31_23 = _mm256_setzero_si256(); + + int j = 0; + for (; j < nn4; j++) + { + __m128i _val0123 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val0123_16 = _mm256_cvtepi8_epi16(_val0123); + + __m256i _val01_16 = _mm256_permute4x64_epi64(_val0123_16, _MM_SHUFFLE(1, 1, 0, 0)); + __m256i _val23_16 = _mm256_permute4x64_epi64(_val0123_16, _MM_SHUFFLE(3, 3, 2, 2)); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + __m256i _val32_16 = _mm256_permute4x64_epi64(_val23_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl20_31 = _mm256_mullo_epi16(_val23_16, _w01_16); + __m256i _sh20_31 = _mm256_mulhi_epi16(_val23_16, _w01_16); + __m256i _sl30_21 = _mm256_mullo_epi16(_val32_16, _w01_16); + __m256i _sh30_21 = _mm256_mulhi_epi16(_val32_16, _w01_16); + + _sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_unpacklo_epi16(_sl20_31, _sh20_31)); + _sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_unpacklo_epi16(_sl30_21, _sh30_21)); + _sum21_33 = _mm256_add_epi32(_sum21_33, _mm256_unpackhi_epi16(_sl20_31, 
_sh20_31)); + _sum31_23 = _mm256_add_epi32(_sum31_23, _mm256_unpackhi_epi16(_sl30_21, _sh30_21)); + + tmpptr += 16; + kptr0 += 16; + } + + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02); + _tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02); + _tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum20_32, _sum30_22); + _tmp1 = _mm256_unpacklo_epi32(_sum21_33, _sum31_23); + _tmp2 = _mm256_unpackhi_epi32(_sum20_32, _sum30_22); + _tmp3 = _mm256_unpackhi_epi32(_sum21_33, _sum31_23); + _sum20_32 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum30_22 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum21_33 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum31_23 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02); + _sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13); + + _sum20_32 = _mm256_add_epi32(_sum20_32, _sum30_22); + _sum21_33 = _mm256_add_epi32(_sum21_33, _sum31_23); + _sum20_32 = _mm256_add_epi32(_sum20_32, _sum21_33); + + __m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0); + _sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask); + _sum20_32 = _mm256_permutevar8x32_epi32(_sum20_32, _perm_mask); + } + + __m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0); + __m128i _sum10 = _mm256_extracti128_si256(_sum00_12, 1); + __m128i _sum20 = _mm256_extracti128_si256(_sum20_32, 0); + __m128i _sum30 = _mm256_extracti128_si256(_sum20_32, 1); + + int j = 0; + for (; j < nn1; j++) + { + __m128i _val01 = _mm_set_epi16(tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[0], 
tmpptr[0], tmpptr[0], tmpptr[0]); + __m128i _val23 = _mm_set_epi16(tmpptr[3], tmpptr[3], tmpptr[3], tmpptr[3], tmpptr[2], tmpptr[2], tmpptr[2], tmpptr[2]); + + __m128i _w0123 = _mm_set_epi16(kptr0[3], kptr0[2], kptr0[1], kptr0[0], kptr0[3], kptr0[2], kptr0[1], kptr0[0]); + + __m128i _sl00 = _mm_mullo_epi16(_val01, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val01, _w0123); + __m128i _sl10 = _mm_mullo_epi16(_val23, _w0123); + __m128i _sh10 = _mm_mulhi_epi16(_val23, _w0123); + + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum20 = _mm_add_epi32(_sum20, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum30 = _mm_add_epi32(_sum30, _mm_unpackhi_epi16(_sl10, _sh10)); + + tmpptr += 4; + kptr0 += 4; + } + + _mm_storeu_si128((__m128i*)outptr0, _sum00); + _mm_storeu_si128((__m128i*)(outptr0 + 4), _sum10); + _mm_storeu_si128((__m128i*)(outptr0 + 8), _sum20); + _mm_storeu_si128((__m128i*)(outptr0 + 12), _sum30); + outptr0 += 16; + } +#endif + for (; i + 1 < size; i += 2) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + const signed char* tmpptr = tmp.channel(i / 2); +#endif + const signed char* kptr0 = kernel.channel(p); + + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; + +#if __AVX2__ + __m256i _sum00_12 = _mm256_setzero_si256(); +#else + __m128i _sum00 = _mm_setzero_si128(); + __m128i _sum10 = _mm_setzero_si128(); +#endif + + if (nn4 > 0) + { +#if __AVX2__ + __m256i _sum10_02 = _mm256_setzero_si256(); + __m256i _sum01_13 = _mm256_setzero_si256(); + __m256i _sum11_03 = _mm256_setzero_si256(); +#else + __m128i _sum01 = _mm_setzero_si128(); + __m128i _sum02 = _mm_setzero_si128(); + __m128i _sum03 = _mm_setzero_si128(); + __m128i _sum11 = _mm_setzero_si128(); + __m128i _sum12 = _mm_setzero_si128(); + __m128i _sum13 = _mm_setzero_si128(); +#endif + + int j = 0; + for (; j < nn4; j++) + { +#if __AVX2__ + __m128i _val01 = _mm_loadu_si128((const 
__m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + _val01_16 = _mm256_permute4x64_epi64(_val01_16, _MM_SHUFFLE(1, 1, 0, 0)); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + + _sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); +#else + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + _val01 = _mm_unpacklo_epi8(_val01, _extval01); + + __m128i _val0 = _mm_shuffle_epi32(_val01, _MM_SHUFFLE(1, 0, 1, 0)); + __m128i _val1 = _mm_shuffle_epi32(_val01, _MM_SHUFFLE(3, 2, 3, 2)); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl01 = _mm_mullo_epi16(_val0, _w1); + __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1); + __m128i _sl10 = _mm_mullo_epi16(_val1, _w0); + __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0); + __m128i _sl11 = _mm_mullo_epi16(_val1, _w1); + __m128i _sh11 = _mm_mulhi_epi16(_val1, _w1); + + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum01 
= _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl01, _sh01)); + _sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl01, _sh01)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl10, _sh10)); + _sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl11, _sh11)); + _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl11, _sh11)); +#endif + + tmpptr += 8; + kptr0 += 16; + } + +#if __AVX2__ + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02); + _tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02); + _tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02); + _sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03); + _sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13); + + __m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0); + _sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask); +#else + // transpose 4x4 + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); + _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); + _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); + _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); + _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); + _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); + _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); + _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); + _sum10 = 
_mm_unpacklo_epi64(_tmp0, _tmp1); + _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00 = _mm_add_epi32(_sum00, _sum01); + _sum02 = _mm_add_epi32(_sum02, _sum03); + _sum10 = _mm_add_epi32(_sum10, _sum11); + _sum12 = _mm_add_epi32(_sum12, _sum13); + + _sum00 = _mm_add_epi32(_sum00, _sum02); + _sum10 = _mm_add_epi32(_sum10, _sum12); +#endif + } + +#if __AVX2__ + __m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0); + __m128i _sum10 = _mm256_extracti128_si256(_sum00_12, 1); +#endif + + int j = 0; + for (; j < nn1; j++) + { + __m128i _val = _mm_set_epi16(tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[1], tmpptr[0], tmpptr[0], tmpptr[0], tmpptr[0]); + + __m128i _w0123 = _mm_set_epi16(kptr0[3], kptr0[2], kptr0[1], kptr0[0], kptr0[3], kptr0[2], kptr0[1], kptr0[0]); + + __m128i _sl00 = _mm_mullo_epi16(_val, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val, _w0123); + + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl00, _sh00)); + + tmpptr += 2; + kptr0 += 4; + } + + _mm_storeu_si128((__m128i*)outptr0, _sum00); + _mm_storeu_si128((__m128i*)(outptr0 + 4), _sum10); + outptr0 += 8; + } + for (; i < size; i++) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + const signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + const signed char* kptr0 = kernel.channel(p); + + int nn4 = (inch / 4) * maxk; + int nn1 = (inch % 4) * maxk; + + __m128i _sum0 = _mm_setzero_si128(); + + if (nn4 > 0) + { + __m128i _sum1 = _mm_setzero_si128(); + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); + + int j = 0; + for (; j < nn4; j++) + { + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + __m128i _val0 = _mm_unpacklo_epi8(_val01, 
_extval01); + + _val0 = _mm_shuffle_epi32(_val0, _MM_SHUFFLE(1, 0, 1, 0)); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl01 = _mm_mullo_epi16(_val0, _w1); + __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl01, _sh01)); + _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl01, _sh01)); + + tmpptr += 4; + kptr0 += 16; + } + + // transpose 4x4 + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); + _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); + _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum0 = _mm_add_epi32(_sum0, _sum1); + _sum2 = _mm_add_epi32(_sum2, _sum3); + _sum0 = _mm_add_epi32(_sum0, _sum2); + } + + int j = 0; + for (; j < nn1; j++) + { + __m128i _val = _mm_set1_epi16(tmpptr[0]); + + __m128i _w0123 = _mm_set_epi16(0, 0, 0, 0, kptr0[3], kptr0[2], kptr0[1], kptr0[0]); + + __m128i _sl00 = _mm_mullo_epi16(_val, _w0123); + __m128i _sh00 = _mm_mulhi_epi16(_val, _w0123); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); + + tmpptr += 1; + kptr0 += 4; + } + + _mm_storeu_si128((__m128i*)outptr0, _sum0); + outptr0 += 4; + } + } +} + +static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + const int 
maxk = kernel_w * kernel_h; + + // interleave + // src = maxk-inch-outch + // dst = 4a-4b-maxk-inch/4a-outch/4b + Mat kernel = _kernel.reshape(maxk, inch, outch); + if (inch >= 4) + kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4, (size_t)1u); + else + kernel_tm.create(4 * maxk, inch, outch / 4, (size_t)1u); + + for (int q = 0; q + 3 < outch; q += 4) + { + signed char* g00 = kernel_tm.channel(q / 4); + + int p = 0; + for (; p + 3 < inch; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const signed char* k00 = kernel.channel(q + i).row(p + j); + + g00[0] = k00[k]; + + g00++; + } + } + } + } + for (; p < inch; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + const signed char* k00 = kernel.channel(q + i).row(p); + + g00[0] = k00[k]; + + g00++; + } + } + } + } +} + +static void convolution_im2col_sgemm_pack1to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = kernel_w * kernel_h; + + // im2col + Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator); + { + const int gap = w * stride_h - outw * stride_w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + signed char* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < kernel_h; u++) + { + for (int v = 0; v < kernel_w; v++) + { + const signed char* sptr = img.row(dilation_h * u) + dilation_w * v; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j + 3 < outw; j += 4) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[stride_w]; + ptr[2] = sptr[stride_w * 2]; + ptr[3] = sptr[stride_w * 3]; + + sptr += stride_w * 4; + ptr += 
4; + } + for (; j + 1 < outw; j += 2) + { + ptr[0] = sptr[0]; + ptr[1] = sptr[stride_w]; + + sptr += stride_w * 2; + ptr += 2; + } + for (; j < outw; j++) + { + ptr[0] = sptr[0]; + + sptr += stride_w; + ptr += 1; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack1to4_int8_sse(bottom_im2col, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_sgemm_pack8to1_int8.h b/src/layer/x86/convolution_sgemm_pack8to1_int8.h new file mode 100644 index 000000000..c2601523c --- /dev/null +++ b/src/layer/x86/convolution_sgemm_pack8to1_int8.h @@ -0,0 +1,839 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void im2col_sgemm_pack8to1_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt) +{ + // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); + + const int size = bottom_im2col.w; + const int maxk = bottom_im2col.h; + const int inch = bottom_im2col.c; + + const int outch = top_blob.c; + + // permute + Mat tmp; +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator); +#else + if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator); +#endif + { +#if __AVX2__ + int remain_size_start = 0; + int nn_size = size >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + + int64_t* tmpptr = tmp.channel(i / 4); + + for (int q = 0; q < inch; q++) + { + const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + __m256i _v = _mm256_loadu_si256((const __m256i*)img0); + _mm256_storeu_si256((__m256i*)tmpptr, _v); + tmpptr += 4; + img0 += size; + } + } + } + + remain_size_start += nn_size << 2; + nn_size = (size - remain_size_start) >> 1; +#else + int remain_size_start = 0; + int nn_size = (size - remain_size_start) >> 1; +#endif + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 2; + +#if __AVX2__ + int64_t* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + int64_t* tmpptr = tmp.channel(i / 2); +#endif + + for (int q = 0; q < inch; q++) + { + const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + 
{ + __m128i _v = _mm_loadu_si128((const __m128i*)img0); + _mm_storeu_si128((__m128i*)tmpptr, _v); + tmpptr += 2; + img0 += size; + } + } + } + + remain_size_start += nn_size << 1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) + { +#if __AVX2__ + int64_t* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + int64_t* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + + for (int q = 0; q < inch; q++) + { + const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + tmpptr[0] = img0[0]; + tmpptr += 1; + img0 += size; + } + } + } + } + + int nn_outch = 0; + int remain_outch_start = 0; + + nn_outch = outch >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outch; pp++) + { + int p = pp * 4; + + int* outptr0 = top_blob.channel(p); + int* outptr1 = top_blob.channel(p + 1); + int* outptr2 = top_blob.channel(p + 2); + int* outptr3 = top_blob.channel(p + 3); + + int i = 0; +#if __AVX2__ + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p / 4); + + int nn = inch * maxk; // inch always > 0 + + __m256i _sum00_11 = _mm256_setzero_si256(); + __m256i _sum10_01 = _mm256_setzero_si256(); + __m256i _sum02_13 = _mm256_setzero_si256(); + __m256i _sum12_03 = _mm256_setzero_si256(); + + __m256i _sum04_15 = _mm256_setzero_si256(); + __m256i _sum14_05 = _mm256_setzero_si256(); + __m256i _sum06_17 = _mm256_setzero_si256(); + __m256i _sum16_07 = _mm256_setzero_si256(); + + int j = 0; + for (; j < nn; j++) + { + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + __m256i _w23_16 = _mm256_cvtepi8_epi16(_w23); + + __m256i _val10_16 = 
_mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl02_13 = _mm256_mullo_epi16(_val01_16, _w23_16); + __m256i _sh02_13 = _mm256_mulhi_epi16(_val01_16, _w23_16); + __m256i _sl12_03 = _mm256_mullo_epi16(_val10_16, _w23_16); + __m256i _sh12_03 = _mm256_mulhi_epi16(_val10_16, _w23_16); + + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpacklo_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpacklo_epi16(_sl12_03, _sh12_03)); + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpackhi_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpackhi_epi16(_sl12_03, _sh12_03)); + + __m128i _val23 = _mm_loadu_si128((const __m128i*)(tmpptr + 16)); + __m256i _val23_16 = _mm256_cvtepi8_epi16(_val23); + __m256i _val32_16 = _mm256_permute4x64_epi64(_val23_16, 78); + + __m256i _sl04_15 = _mm256_mullo_epi16(_val23_16, _w01_16); + __m256i _sh04_15 = _mm256_mulhi_epi16(_val23_16, _w01_16); + __m256i _sl14_05 = _mm256_mullo_epi16(_val32_16, _w01_16); + __m256i _sh14_05 = _mm256_mulhi_epi16(_val32_16, _w01_16); + __m256i _sl06_17 = _mm256_mullo_epi16(_val23_16, _w23_16); + __m256i _sh06_17 = _mm256_mulhi_epi16(_val23_16, _w23_16); + __m256i _sl16_07 = _mm256_mullo_epi16(_val32_16, _w23_16); + __m256i _sh16_07 = _mm256_mulhi_epi16(_val32_16, _w23_16); + + _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_unpacklo_epi16(_sl04_15, _sh04_15)); + _sum14_05 = 
_mm256_add_epi32(_sum14_05, _mm256_unpacklo_epi16(_sl14_05, _sh14_05)); + _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_unpacklo_epi16(_sl06_17, _sh06_17)); + _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_unpacklo_epi16(_sl16_07, _sh16_07)); + _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_unpackhi_epi16(_sl04_15, _sh04_15)); + _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_unpackhi_epi16(_sl14_05, _sh14_05)); + _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_unpackhi_epi16(_sl06_17, _sh06_17)); + _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_unpackhi_epi16(_sl16_07, _sh16_07)); + + tmpptr += 32; + kptr0 += 32; + } + + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); + _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); + _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); + _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); + _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); + _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); + _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); + _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); + + _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); + _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); + _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); + + __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); + _sum00_11 = 
_mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); + _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); + + int sum[16]; + _mm256_storeu_si256((__m256i*)sum, _sum00_11); + _mm256_storeu_si256((__m256i*)(sum + 8), _sum04_15); + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0[1] = sum[4]; + outptr1[1] = sum[5]; + outptr2[1] = sum[6]; + outptr3[1] = sum[7]; + outptr0[2] = sum[8]; + outptr1[2] = sum[9]; + outptr2[2] = sum[10]; + outptr3[2] = sum[11]; + outptr0[3] = sum[12]; + outptr1[3] = sum[13]; + outptr2[3] = sum[14]; + outptr3[3] = sum[15]; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } +#endif + for (; i + 1 < size; i += 2) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + const signed char* tmpptr = tmp.channel(i / 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4); + + int nn = inch * maxk; // inch always > 0 + +#if __AVX2__ + __m256i _sum00_11 = _mm256_setzero_si256(); + __m256i _sum10_01 = _mm256_setzero_si256(); + __m256i _sum02_13 = _mm256_setzero_si256(); + __m256i _sum12_03 = _mm256_setzero_si256(); +#else + __m128i _sum00 = _mm_setzero_si128(); + __m128i _sum01 = _mm_setzero_si128(); + __m128i _sum02 = _mm_setzero_si128(); + __m128i _sum03 = _mm_setzero_si128(); + __m128i _sum10 = _mm_setzero_si128(); + __m128i _sum11 = _mm_setzero_si128(); + __m128i _sum12 = _mm_setzero_si128(); + __m128i _sum13 = _mm_setzero_si128(); +#endif + + int j = 0; + for (; j < nn; j++) + { +#if __AVX2__ + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + __m256i _w23_16 = _mm256_cvtepi8_epi16(_w23); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = 
_mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl02_13 = _mm256_mullo_epi16(_val01_16, _w23_16); + __m256i _sh02_13 = _mm256_mulhi_epi16(_val01_16, _w23_16); + __m256i _sl12_03 = _mm256_mullo_epi16(_val10_16, _w23_16); + __m256i _sh12_03 = _mm256_mulhi_epi16(_val10_16, _w23_16); + + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpacklo_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpacklo_epi16(_sl12_03, _sh12_03)); + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpackhi_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpackhi_epi16(_sl12_03, _sh12_03)); +#else + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01); + __m128i _val1 = _mm_unpackhi_epi8(_val01, _extval01); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23); + __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i 
_sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl01 = _mm_mullo_epi16(_val0, _w1); + __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1); + __m128i _sl02 = _mm_mullo_epi16(_val0, _w2); + __m128i _sh02 = _mm_mulhi_epi16(_val0, _w2); + __m128i _sl03 = _mm_mullo_epi16(_val0, _w3); + __m128i _sh03 = _mm_mulhi_epi16(_val0, _w3); + __m128i _sl10 = _mm_mullo_epi16(_val1, _w0); + __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0); + __m128i _sl11 = _mm_mullo_epi16(_val1, _w1); + __m128i _sh11 = _mm_mulhi_epi16(_val1, _w1); + __m128i _sl12 = _mm_mullo_epi16(_val1, _w2); + __m128i _sh12 = _mm_mulhi_epi16(_val1, _w2); + __m128i _sl13 = _mm_mullo_epi16(_val1, _w3); + __m128i _sh13 = _mm_mulhi_epi16(_val1, _w3); + + _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum01 = _mm_add_epi32(_sum01, _mm_unpacklo_epi16(_sl01, _sh01)); + _sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl02, _sh02)); + _sum03 = _mm_add_epi32(_sum03, _mm_unpacklo_epi16(_sl03, _sh03)); + _sum00 = _mm_add_epi32(_sum00, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl01, _sh01)); + _sum02 = _mm_add_epi32(_sum02, _mm_unpackhi_epi16(_sl02, _sh02)); + _sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl03, _sh03)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum11 = _mm_add_epi32(_sum11, _mm_unpacklo_epi16(_sl11, _sh11)); + _sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl12, _sh12)); + _sum13 = _mm_add_epi32(_sum13, _mm_unpacklo_epi16(_sl13, _sh13)); + _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl10, _sh10)); + _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl11, _sh11)); + _sum12 = _mm_add_epi32(_sum12, _mm_unpackhi_epi16(_sl12, _sh12)); + _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl13, _sh13)); +#endif + + tmpptr += 16; + kptr0 += 32; + } + +#if __AVX2__ + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); + _tmp1 = 
_mm256_unpacklo_epi32(_sum02_13, _sum12_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); + _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); + _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); + + __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); + _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); + + int sum[8]; + _mm256_storeu_si256((__m256i*)sum, _sum00_11); +#else + // transpose 4x4 + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); + _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); + _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); + _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); + _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); + _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); + _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); + _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); + _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00 = _mm_add_epi32(_sum00, _sum01); + _sum02 = _mm_add_epi32(_sum02, _sum03); + _sum10 = _mm_add_epi32(_sum10, _sum11); + _sum12 = _mm_add_epi32(_sum12, _sum13); + + _sum00 = _mm_add_epi32(_sum00, _sum02); + _sum10 = _mm_add_epi32(_sum10, _sum12); + + int sum[8]; + _mm_storeu_si128((__m128i*)sum, _sum00); + _mm_storeu_si128((__m128i*)(sum + 4), _sum10); +#endif + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + 
outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0[1] = sum[4]; + outptr1[1] = sum[5]; + outptr2[1] = sum[6]; + outptr3[1] = sum[7]; + outptr0 += 2; + outptr1 += 2; + outptr2 += 2; + outptr3 += 2; + } + for (; i < size; i++) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + const signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4); + + int nn = inch * maxk; // inch always > 0 + + __m128i _sum0 = _mm_setzero_si128(); + __m128i _sum1 = _mm_setzero_si128(); + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); + + int j = 0; + for (; j < nn; j++) + { + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val = _mm_loadl_epi64((const __m128i*)tmpptr); + _val = _mm_unpacklo_epi8(_val, _mm_cmpgt_epi8(_mm_setzero_si128(), _val)); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); + __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23); + __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23); + + __m128i _sl0 = _mm_mullo_epi16(_val, _w0); + __m128i _sh0 = _mm_mulhi_epi16(_val, _w0); + __m128i _sl1 = _mm_mullo_epi16(_val, _w1); + __m128i _sh1 = _mm_mulhi_epi16(_val, _w1); + __m128i _sl2 = _mm_mullo_epi16(_val, _w2); + __m128i _sh2 = _mm_mulhi_epi16(_val, _w2); + __m128i _sl3 = _mm_mullo_epi16(_val, _w3); + __m128i _sh3 = _mm_mulhi_epi16(_val, _w3); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl0, _sh0)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpacklo_epi16(_sl1, _sh1)); + _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl2, _sh2)); + _sum3 = _mm_add_epi32(_sum3, _mm_unpacklo_epi16(_sl3, _sh3)); + _sum0 = 
_mm_add_epi32(_sum0, _mm_unpackhi_epi16(_sl0, _sh0)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl1, _sh1)); + _sum2 = _mm_add_epi32(_sum2, _mm_unpackhi_epi16(_sl2, _sh2)); + _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl3, _sh3)); + + tmpptr += 8; + kptr0 += 32; + } + + // transpose 4x4 + { + __m128i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); + _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); + _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); + _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); + _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum0 = _mm_add_epi32(_sum0, _sum1); + _sum2 = _mm_add_epi32(_sum2, _sum3); + + _sum0 = _mm_add_epi32(_sum0, _sum2); + + int sum[4]; + _mm_storeu_si128((__m128i*)sum, _sum0); + + outptr0[0] = sum[0]; + outptr1[0] = sum[1]; + outptr2[0] = sum[2]; + outptr3[0] = sum[3]; + outptr0 += 1; + outptr1 += 1; + outptr2 += 1; + outptr3 += 1; + } + } + + remain_outch_start += nn_outch << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = remain_outch_start; p < outch; p++) + { + int* outptr0 = top_blob.channel(p); + + int i = 0; +#if __AVX2__ + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); + + int nn = inch * maxk; // inch always > 0 + + __m256i _sum0_2 = _mm256_setzero_si256(); + __m256i _sum1_3 = _mm256_setzero_si256(); + __m256i _sum4_6 = _mm256_setzero_si256(); + __m256i _sum5_7 = _mm256_setzero_si256(); + + int j = 0; + for (; j < nn; j++) + { + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _val23 = _mm_loadu_si128((const __m128i*)(tmpptr + 16)); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + __m256i _val23_16 = _mm256_cvtepi8_epi16(_val23); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = 
_mm256_cvtepi8_epi16(_w01); + _w01_16 = _mm256_permute4x64_epi64(_w01_16, _MM_SHUFFLE(1, 0, 1, 0)); + + __m256i _sl00_10 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_10 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl20_30 = _mm256_mullo_epi16(_val23_16, _w01_16); + __m256i _sh20_30 = _mm256_mulhi_epi16(_val23_16, _w01_16); + + _sum0_2 = _mm256_add_epi32(_sum0_2, _mm256_unpacklo_epi16(_sl00_10, _sh00_10)); + _sum1_3 = _mm256_add_epi32(_sum1_3, _mm256_unpackhi_epi16(_sl00_10, _sh00_10)); + _sum4_6 = _mm256_add_epi32(_sum4_6, _mm256_unpacklo_epi16(_sl20_30, _sh20_30)); + _sum5_7 = _mm256_add_epi32(_sum5_7, _mm256_unpackhi_epi16(_sl20_30, _sh20_30)); + + tmpptr += 32; + kptr0 += 8; + } + + _sum0_2 = _mm256_add_epi32(_sum0_2, _sum1_3); + _sum4_6 = _mm256_add_epi32(_sum4_6, _sum5_7); + __m128i _sum0 = _mm256_extracti128_si256(_sum0_2, 0); + __m128i _sum2 = _mm256_extracti128_si256(_sum0_2, 1); + __m128i _sum4 = _mm256_extracti128_si256(_sum4_6, 0); + __m128i _sum6 = _mm256_extracti128_si256(_sum4_6, 1); + + outptr0[0] = _mm_reduce_add_epi32(_sum0); + outptr0[1] = _mm_reduce_add_epi32(_sum2); + outptr0[2] = _mm_reduce_add_epi32(_sum4); + outptr0[3] = _mm_reduce_add_epi32(_sum6); + outptr0 += 4; + } +#endif + for (; i + 1 < size; i += 2) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else + const signed char* tmpptr = tmp.channel(i / 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); + + int nn = inch * maxk; // inch always > 0 + +#if __AVX2__ + __m256i _sum0_2 = _mm256_setzero_si256(); + __m256i _sum1_3 = _mm256_setzero_si256(); +#else + __m128i _sum0 = _mm_setzero_si128(); + __m128i _sum1 = _mm_setzero_si128(); + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); +#endif + + int j = 0; + for (; j < nn; j++) + { +#if __AVX2__ + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + __m128i _w01 =
_mm_loadu_si128((const __m128i*)kptr0); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + _w01_16 = _mm256_permute4x64_epi64(_w01_16, _MM_SHUFFLE(1, 0, 1, 0)); + + __m256i _sl00_10 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_10 = _mm256_mulhi_epi16(_val01_16, _w01_16); + + _sum0_2 = _mm256_add_epi32(_sum0_2, _mm256_unpacklo_epi16(_sl00_10, _sh00_10)); + _sum1_3 = _mm256_add_epi32(_sum1_3, _mm256_unpackhi_epi16(_sl00_10, _sh00_10)); +#else + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01); + __m128i _val1 = _mm_unpackhi_epi8(_val01, _extval01); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + __m128i _sl10 = _mm_mullo_epi16(_val1, _w0); + __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl00, _sh00)); + _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl10, _sh10)); + _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl10, _sh10)); +#endif + + tmpptr += 16; + kptr0 += 8; + } + +#if __AVX2__ + _sum0_2 = _mm256_add_epi32(_sum0_2, _sum1_3); + __m128i _sum0 = _mm256_extracti128_si256(_sum0_2, 0); + __m128i _sum2 = _mm256_extracti128_si256(_sum0_2, 1); +#else + _sum0 = _mm_add_epi32(_sum0, _sum1); + _sum2 = _mm_add_epi32(_sum2, _sum3); +#endif + + outptr0[0] = _mm_reduce_add_epi32(_sum0); + outptr0[1] = _mm_reduce_add_epi32(_sum2); + outptr0 += 2; + } + for (; i < size; i++) + { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else + const signed char* tmpptr = tmp.channel(i / 2 
+ i % 2); +#endif + const signed char* kptr0 = kernel.channel(p / 4 + p % 4); + + int nn = inch * maxk; // inch always > 0 + + __m128i _sum0 = _mm_setzero_si128(); + __m128i _sum1 = _mm_setzero_si128(); + + int j = 0; + for (; j < nn; j++) + { + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); + __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01); + + // TODO use _mm_cvtepi8_epi16 on sse4.1 + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); + __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); + + __m128i _sl00 = _mm_mullo_epi16(_val0, _w0); + __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0); + + _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl00, _sh00)); + _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl00, _sh00)); + + tmpptr += 8; + kptr0 += 8; + } + + _sum0 = _mm_add_epi32(_sum0, _sum1); + + outptr0[0] = _mm_reduce_add_epi32(_sum0); + outptr0 += 1; + } + } +} + +static void convolution_im2col_sgemm_transform_kernel_pack8to1_int8_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + const int maxk = kernel_w * kernel_h; + + // interleave + // src = maxk-inch-outch + // dst = 8a-4b-maxk-inch/8a-outch/4b + Mat kernel = _kernel.reshape(maxk, inch, outch); + if (outch >= 4) + kernel_tm.create(32 * maxk, inch / 8, outch / 4 + outch % 4, (size_t)1u); + else + kernel_tm.create(8 * maxk, inch / 8, outch, (size_t)1u); + + int q = 0; + for (; q + 3 < outch; q += 4) + { + signed char* g00 = kernel_tm.channel(q / 4); + + for (int p = 0; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j); + + g00[0] = k00[k]; + + g00++; + } + } + } + } + } + // TODO unroll 2 + for (; q < outch; q++) + { + signed char* g00 =
kernel_tm.channel(q / 4 + q % 4); + + for (int p = 0; p + 7 < inch; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 8; j++) + { + const signed char* k00 = kernel.channel(q).row<const signed char>(p + j); + + g00[0] = k00[k]; + + g00++; + } + } + } + } +} + +static void convolution_im2col_sgemm_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = kernel_w * kernel_h; + + // im2col + Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator); + { + const int gap = w * stride_h - outw * stride_w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + int64_t* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < kernel_h; u++) + { + for (int v = 0; v < kernel_w; v++) + { + const int64_t* sptr = img.row<const int64_t>(dilation_h * u) + dilation_w * v; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j < outw; j++) + { + ptr[0] = sptr[0]; + + sptr += stride_w; + ptr += 1; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack8to1_int8_sse(bottom_im2col, top_blob, kernel, opt); +} diff --git a/src/layer/x86/convolution_sgemm_pack8to4_int8.h b/src/layer/x86/convolution_sgemm_pack8to4_int8.h index 0d66dc07a..0bbb2c5a7 100644 --- a/src/layer/x86/convolution_sgemm_pack8to4_int8.h +++ b/src/layer/x86/convolution_sgemm_pack8to4_int8.h @@ -24,20 +24,62 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl // permute Mat tmp; +#if __AVX2__ + if (size >= 4) + tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator); + else if (size >= 2) + tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator); +
else + tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator); +#else if (size >= 2) tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator); else tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator); +#endif { +#if __AVX2__ + int remain_size_start = 0; + int nn_size = size >> 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + + int64_t* tmpptr = tmp.channel(i / 4); + + for (int q = 0; q < inch; q++) + { + const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i; + + for (int k = 0; k < maxk; k++) + { + __m256i _v = _mm256_loadu_si256((const __m256i*)img0); + _mm256_storeu_si256((__m256i*)tmpptr, _v); + tmpptr += 4; + img0 += size; + } + } + } + + remain_size_start += nn_size << 2; + nn_size = (size - remain_size_start) >> 1; +#else int remain_size_start = 0; int nn_size = size >> 1; +#endif #pragma omp parallel for num_threads(opt.num_threads) for (int ii = 0; ii < nn_size; ii++) { int i = remain_size_start + ii * 2; +#if __AVX2__ + int64_t* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else int64_t* tmpptr = tmp.channel(i / 2); +#endif for (int q = 0; q < inch; q++) { @@ -58,7 +100,11 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl #pragma omp parallel for num_threads(opt.num_threads) for (int i = remain_size_start; i < size; i++) { +#if __AVX2__ + int64_t* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else int64_t* tmpptr = tmp.channel(i / 2 + i % 2); +#endif for (int q = 0; q < inch; q++) { @@ -80,13 +126,139 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl int* outptr0 = top_blob.channel(p); int i = 0; +#if __AVX2__ + for (; i + 3 < size; i += 4) + { + const signed char* tmpptr = tmp.channel(i / 4); + const signed char* kptr0 = kernel.channel(p); + + int nn = inch * maxk; // inch always > 0 + + __m256i _sum00_11 = _mm256_setzero_si256(); + 
__m256i _sum10_01 = _mm256_setzero_si256(); + __m256i _sum02_13 = _mm256_setzero_si256(); + __m256i _sum12_03 = _mm256_setzero_si256(); + + __m256i _sum04_15 = _mm256_setzero_si256(); + __m256i _sum14_05 = _mm256_setzero_si256(); + __m256i _sum06_17 = _mm256_setzero_si256(); + __m256i _sum16_07 = _mm256_setzero_si256(); + + int j = 0; + for (; j < nn; j++) + { + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + __m256i _w23_16 = _mm256_cvtepi8_epi16(_w23); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl02_13 = _mm256_mullo_epi16(_val01_16, _w23_16); + __m256i _sh02_13 = _mm256_mulhi_epi16(_val01_16, _w23_16); + __m256i _sl12_03 = _mm256_mullo_epi16(_val10_16, _w23_16); + __m256i _sh12_03 = _mm256_mulhi_epi16(_val10_16, _w23_16); + + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpacklo_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpacklo_epi16(_sl12_03, _sh12_03)); + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpackhi_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpackhi_epi16(_sl12_03, _sh12_03)); + + __m128i _val23 = _mm_loadu_si128((const __m128i*)(tmpptr 
+ 16)); + __m256i _val23_16 = _mm256_cvtepi8_epi16(_val23); + __m256i _val32_16 = _mm256_permute4x64_epi64(_val23_16, 78); + + __m256i _sl04_15 = _mm256_mullo_epi16(_val23_16, _w01_16); + __m256i _sh04_15 = _mm256_mulhi_epi16(_val23_16, _w01_16); + __m256i _sl14_05 = _mm256_mullo_epi16(_val32_16, _w01_16); + __m256i _sh14_05 = _mm256_mulhi_epi16(_val32_16, _w01_16); + __m256i _sl06_17 = _mm256_mullo_epi16(_val23_16, _w23_16); + __m256i _sh06_17 = _mm256_mulhi_epi16(_val23_16, _w23_16); + __m256i _sl16_07 = _mm256_mullo_epi16(_val32_16, _w23_16); + __m256i _sh16_07 = _mm256_mulhi_epi16(_val32_16, _w23_16); + + _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_unpacklo_epi16(_sl04_15, _sh04_15)); + _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_unpacklo_epi16(_sl14_05, _sh14_05)); + _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_unpacklo_epi16(_sl06_17, _sh06_17)); + _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_unpacklo_epi16(_sl16_07, _sh16_07)); + _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_unpackhi_epi16(_sl04_15, _sh04_15)); + _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_unpackhi_epi16(_sl14_05, _sh14_05)); + _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_unpackhi_epi16(_sl06_17, _sh06_17)); + _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_unpackhi_epi16(_sl16_07, _sh16_07)); + + tmpptr += 32; + kptr0 += 32; + } + + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); + _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); + _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); + _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); + _tmp2 = 
_mm256_unpackhi_epi32(_sum04_15, _sum14_05); + _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); + _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); + _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); + + _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); + _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); + _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); + + __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); + _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); + _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); + + _mm256_storeu_si256((__m256i*)outptr0, _sum00_11); + _mm256_storeu_si256((__m256i*)(outptr0 + 8), _sum04_15); + outptr0 += 16; + } +#endif for (; i + 1 < size; i += 2) { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2); +#else const signed char* tmpptr = tmp.channel(i / 2); +#endif const signed char* kptr0 = kernel.channel(p); int nn = inch * maxk; // inch always > 0 +#if __AVX2__ + __m256i _sum00_11 = _mm256_setzero_si256(); + __m256i _sum10_01 = _mm256_setzero_si256(); + __m256i _sum02_13 = _mm256_setzero_si256(); + __m256i _sum12_03 = _mm256_setzero_si256(); +#else __m128i _sum00 = _mm_setzero_si128(); __m128i _sum01 = _mm_setzero_si128(); __m128i _sum02 = _mm_setzero_si128(); @@ -95,10 +267,40 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl __m128i _sum11 = _mm_setzero_si128(); __m128i _sum12 = _mm_setzero_si128(); __m128i _sum13 = _mm_setzero_si128(); +#endif int j = 0; for (; j < nn; j++) { +#if __AVX2__ + __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); + __m256i _val01_16 = _mm256_cvtepi8_epi16(_val01); + + __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0); + 
__m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16)); + __m256i _w01_16 = _mm256_cvtepi8_epi16(_w01); + __m256i _w23_16 = _mm256_cvtepi8_epi16(_w23); + + __m256i _val10_16 = _mm256_permute4x64_epi64(_val01_16, 78); + + __m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16); + __m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16); + __m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16); + __m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16); + __m256i _sl02_13 = _mm256_mullo_epi16(_val01_16, _w23_16); + __m256i _sh02_13 = _mm256_mulhi_epi16(_val01_16, _w23_16); + __m256i _sl12_03 = _mm256_mullo_epi16(_val10_16, _w23_16); + __m256i _sh12_03 = _mm256_mulhi_epi16(_val10_16, _w23_16); + + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpacklo_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpacklo_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpacklo_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpacklo_epi16(_sl12_03, _sh12_03)); + _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_unpackhi_epi16(_sl00_11, _sh00_11)); + _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_unpackhi_epi16(_sl10_01, _sh10_01)); + _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_unpackhi_epi16(_sl02_13, _sh02_13)); + _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_unpackhi_epi16(_sl12_03, _sh12_03)); +#else // TODO use _mm_cvtepi8_epi16 on sse4.1 __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr); __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01); @@ -148,11 +350,35 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl11, _sh11)); _sum12 = _mm_add_epi32(_sum12, _mm_unpackhi_epi16(_sl12, _sh12)); _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl13, _sh13)); +#endif tmpptr += 16; kptr0 += 32; } +#if __AVX2__ + // transpose 4x8 + { + __m256i _tmp0, _tmp1, _tmp2, _tmp3; + _tmp0 = 
_mm256_unpacklo_epi32(_sum00_11, _sum10_01); + _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); + _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); + _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); + _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); + _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); + } + + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); + _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); + _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); + + __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); + _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); + + _mm256_storeu_si256((__m256i*)outptr0, _sum00_11); +#else // transpose 4x4 { __m128i _tmp0, _tmp1, _tmp2, _tmp3; @@ -187,11 +413,16 @@ static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_bl _mm_storeu_si128((__m128i*)outptr0, _sum00); _mm_storeu_si128((__m128i*)(outptr0 + 4), _sum10); +#endif outptr0 += 8; } for (; i < size; i++) { +#if __AVX2__ + const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2); +#else const signed char* tmpptr = tmp.channel(i / 2 + i % 2); +#endif const signed char* kptr0 = kernel.channel(p); int nn = inch * maxk; // inch always > 0 diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index e978f4b81..52a145628 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -49,7 +49,13 @@ namespace ncnn { #include "convolution_pack1to4_int8.h" #include "convolution_pack8to1_int8.h" #include "convolution_sgemm_pack8to4_int8.h" +#include "convolution_sgemm_pack1to4_int8.h" +#include "convolution_sgemm_pack8to1_int8.h" #include "convolution_1x1_pack8to4_int8.h" +#include "convolution_1x1_pack1to4_int8.h" +#include "convolution_1x1_pack8to1_int8.h" +#include "convolution_3x3_pack1to4_int8.h" +#include "convolution_7x7_pack1to4_int8.h" #endif // 
NCNN_INT8 #if __AVX__ @@ -1133,36 +1139,9 @@ int Convolution_x86::forward(const std::vector& bottom_blobs, std::vector= 16 && num_output >= 16) - { - conv3x3s1_winograd23_transform_kernel_int8_sse(weight_data, weight_data_3x3_winograd23_int8, num_input, num_output, opt); - // conv3x3s1_winograd43_transform_kernel_int8_sse(weight_data, weight_data_3x3_winograd23_int8, num_input, num_output, opt); - } - else - { - // TODO offline transform weight - } - - return 0; - } // src = kw-kh-inch-outch // dst = pa-pb-kw-kh-inch/pa-outch/pb @@ -1181,11 +1160,11 @@ int Convolution_x86::create_pipeline_int8_x86(const Option& opt) for (int k = 0; k < maxk; k++) { - for (int j = 0; j < out_elempack; j++) + for (int i = 0; i < out_elempack; i++) { - for (int i = 0; i < elempack; i++) + for (int j = 0; j < elempack; j++) { - const signed char* k00 = weight_data_r2.channel(q + j).row(p + i); + const signed char* k00 = weight_data_r2.channel(q + i).row(p + j); g00[0] = k00[k]; @@ -1196,25 +1175,121 @@ int Convolution_x86::create_pipeline_int8_x86(const Option& opt) } } } +} + +int Convolution_x86::create_pipeline_int8_x86(const Option& opt) +{ + const int maxk = kernel_w * kernel_h; + const int num_input = weight_data_size / maxk / num_output; + + int elempack = 1; + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { + elempack = num_input % 8 == 0 ? 8 : 1; + out_elempack = num_output % 4 == 0 ? 
4 : 1; + } +#endif // __SSE2__ #if __SSE2__ if (elempack == 8 && out_elempack == 4) { if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); } else if (opt.use_sgemm_convolution) { - convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h); + convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else + { + convolution_transform_kernel_packed_int8_sse(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); + } + } + + if (elempack == 1 && out_elempack == 4) + { + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && 
stride_w == 1 && stride_h == 1) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) + { + convolution_im2col_sgemm_transform_kernel_pack1to4_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else + { + convolution_transform_kernel_packed_int8_sse(weight_data, weight_data_int8, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); + } + } + + if (elempack == 8 && out_elempack == 1) + { + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_sgemm_transform_kernel_pack8to1_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_pack8to1_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) + { + convolution_im2col_sgemm_transform_kernel_pack8to1_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else + { + convolution_transform_kernel_packed_int8_sse(weight_data, 
weight_data_int8, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); } } #endif // __SSE2__ + if (elempack == 1 && out_elempack == 1) + { + if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + { + conv3x3s1_winograd23_transform_kernel_int8_sse(weight_data, weight_data_3x3_winograd23_int8, num_input, num_output, opt); + // conv3x3s1_winograd43_transform_kernel_int8_sse(weight_data, weight_data_3x3_winograd23_int8, num_input, num_output, opt); + } + + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_sgemm_transform_kernel_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (opt.use_sgemm_convolution) + { + convolution_im2col_sgemm_transform_kernel_int8_sse(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + + return 0; + } + return 0; } @@ -1287,15 +1362,15 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con { if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv1x1s1_sgemm_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s1_sgemm_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); } else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - conv1x1s2_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, opt); + conv1x1s2_sgemm_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, 
weight_sgemm_data, opt); } else if (opt.use_sgemm_convolution) { - convolution_im2col_sgemm_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_im2col_sgemm_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } else { @@ -1332,7 +1407,34 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (elempack == 1 && out_elempack_int32 == 4) { - convolution_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv1x1s1_sgemm_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv1x1s2_sgemm_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv3x3s1_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv3x3s2_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv7x7s2_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) + { + convolution_im2col_sgemm_pack1to4_int8_sse(bottom_blob_bordered, 
top_blob_int32, weight_sgemm_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } + else + { + convolution_pack1to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } Mat scale_in_data(num_output); for (int p = 0; p < num_output; p++) @@ -1364,7 +1466,22 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (elempack == 8 && out_elempack_int32 == 1) { - convolution_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv1x1s1_sgemm_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv1x1s2_sgemm_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (opt.use_sgemm_convolution) // TODO better condition && num_input >= 8 && num_output >= 8) + { + convolution_im2col_sgemm_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } + else + { + convolution_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } Mat scale_in_data(num_output); for (int p = 0; p < num_output; p++) @@ -1397,111 +1514,53 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (elempack == 1 && out_elempack_int32 == 1) { - if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (kernel_w == 1 && kernel_h == 1 && 
dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv1x1s1_sgemm_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv1x1s2_sgemm_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, opt); + } + else if (opt.use_winograd_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) { conv3x3s1_winograd23_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_3x3_winograd23_int8, opt); // conv3x3s1_winograd43_int8_sse(bottom_blob_bordered, top_blob_int32, weight_data_3x3_winograd23_int8, opt); - - Mat scale_in_data(num_output); - for (int p = 0; p < num_output; p++) - { - // requantize and relu - float scale_in; - if (weight_data_int8_scales[p] == 0) - scale_in = 0; - else - scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]); - - scale_in_data[p] = scale_in; - } - - if (use_int8_requantize) - { - requantize_from_int32_to_int8(top_blob_int32, top_blob, scale_in_data, top_blob_int8_scales, bias_data, activation_type, activation_params, opt); - } - else - { - dequantize_from_int32(top_blob_int32, top_blob, scale_in_data, bias_data, opt); - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - } } - else if (opt.use_sgemm_convolution && dilation_w == 1 && dilation_h == 1 && (activation_type == 0 || activation_type == 1)) + else if (opt.use_sgemm_convolution) { - if (use_int8_requantize) - { - std::vector requantize_scales; - for (int p = 0; p < num_output; p++) - { - float scale_in; - if (weight_data_int8_scales[p] == 0) - scale_in = 0; - else - scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]); - - float scale_out = top_blob_int8_scales[0]; - - requantize_scales.push_back(scale_in); - 
requantize_scales.push_back(scale_out); - } + convolution_im2col_sgemm_int8_sse(bottom_blob_bordered, top_blob_int32, weight_sgemm_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } + else + { + // convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + } - conv_im2col_sgemm_int8_requant_sse(bottom_blob_bordered, top_blob, weight_data, kernel_w, kernel_h, stride_w, stride_h, bias_data, requantize_scales, opt); - } + Mat scale_in_data(num_output); + for (int p = 0; p < num_output; p++) + { + // requantize and relu + float scale_in; + if (weight_data_int8_scales[p] == 0) + scale_in = 0; else - { - std::vector dequantize_scales; - for (int p = 0; p < num_output; p++) - { - float scale_in; - if (weight_data_int8_scales[p] == 0) - scale_in = 0; - else - scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]); - - dequantize_scales.push_back(scale_in); - } + scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]); - conv_im2col_sgemm_int8_dequant_sse(bottom_blob_bordered, top_blob, weight_data, kernel_w, kernel_h, stride_w, stride_h, bias_data, dequantize_scales, opt); - } + scale_in_data[p] = scale_in; + } - if (activation) - { - activation->forward_inplace(top_blob, opt); - } + if (use_int8_requantize) + { + requantize_from_int32_to_int8(top_blob_int32, top_blob, scale_in_data, top_blob_int8_scales, bias_data, activation_type, activation_params, opt); } else { - // convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data_int8, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); - convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); - - Mat scale_in_data(num_output); - 
for (int p = 0; p < num_output; p++) - { - // requantize and relu - float scale_in; - if (weight_data_int8_scales[p] == 0) - scale_in = 0; - else - scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]); - - scale_in_data[p] = scale_in; - } + dequantize_from_int32(top_blob_int32, top_blob, scale_in_data, bias_data, opt); - if (use_int8_requantize) - { - requantize_from_int32_to_int8(top_blob_int32, top_blob, scale_in_data, top_blob_int8_scales, bias_data, activation_type, activation_params, opt); - } - else + if (activation) { - dequantize_from_int32(top_blob_int32, top_blob, scale_in_data, bias_data, opt); - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } + activation->forward_inplace(top_blob, opt); } } } diff --git a/src/layer/x86/convolution_x86.h b/src/layer/x86/convolution_x86.h index cbc0ce0fb..89615420b 100644 --- a/src/layer/x86/convolution_x86.h +++ b/src/layer/x86/convolution_x86.h @@ -41,13 +41,16 @@ protected: public: Layer* activation; - Mat weight_data_packed; + Mat weight_sgemm_data; Mat weight_data_3x3_winograd23; Mat weight_data_3x3_winograd63; // forwardDilation Layer* convolution_dilation1; + // pack4/8 + Mat weight_data_packed; + #if NCNN_INT8 // int8 Mat weight_data_int8; diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index c59b4bbc3..f82a33aef 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -15,6 +15,13 @@ #ifndef X86_USABILITY_H #define X86_USABILITY_H +#if __SSE2__ +#include +#if __AVX__ +#include +#endif +#endif // __SSE2__ + static NCNN_FORCEINLINE signed char float2int8(float v) { int int32 = (int)round(v); @@ -24,8 +31,6 @@ static NCNN_FORCEINLINE signed char float2int8(float v) } #if __SSE2__ -#include - static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -33,6 +38,15 @@ static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) return 
_mm_cvtss_f32(x32); } +static NCNN_FORCEINLINE int _mm_reduce_add_epi32(__m128i x) +{ + __m128i hi64 = _mm_unpackhi_epi64(x, x); + __m128i sum64 = _mm_add_epi32(hi64, x); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i sum32 = _mm_add_epi32(sum64, hi32); + return _mm_cvtsi128_si32(sum32); +} + static NCNN_FORCEINLINE int32_t float2int8_sse(const __m128& _v0) { // _MM_ROUND_NEAREST round to even @@ -120,7 +134,7 @@ static NCNN_FORCEINLINE __m128i float2int8_sse(const __m128& _v0, const __m128& return _v8; } -#if __SSE2__ + #ifndef __AVX2__ static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(__m128 _a, const __m128 _b, const __m128 _c) @@ -128,9 +142,8 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(__m128 _a, const __m128 _b, con return _mm_add_ps(_mm_mul_ps(_a, _b), _c); } #endif -#endif + #if __AVX__ -#include #ifndef __AVX2__ static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(__m256 _a, const __m256 _b, const __m256 _c) { @@ -152,13 +165,14 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(__m256 _a, const __m256 _b, return _mm256_fmadd_ps(_a, _b, _c); } #endif -#if __AVX2__ +#if __AVX2__ static NCNN_FORCEINLINE __m256 loadfp16(const unsigned short* ptr) { return _mm256_cvtph_ps(_mm_lddqu_si128((__m128i*)(ptr))); } #endif + static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(__m256 a, __m256 b, float c) { return _mm256_comp_fmadd_ps(b, _mm256_set1_ps(c), a);