From 1133a18ca83162a1a49cba4c75f3a4ecea74fb53 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 10 Apr 2023 22:34:14 +0800 Subject: [PATCH] x86 and arm optimization for convolution1d packed unified elempack (#4615) --- src/layer/arm/convolution1d_arm.cpp | 482 +--- src/layer/arm/convolution1d_arm.h | 9 +- src/layer/arm/convolution1d_arm_asimdhp.cpp | 663 +---- src/layer/arm/convolution1d_packed.h | 1272 +++++++++ src/layer/arm/convolution1d_packed_bf16s.h | 1311 +++++++++ src/layer/arm/convolution1d_packed_fp16s.h | 1848 +++++++++++++ src/layer/arm/convolution_packed_fp16s.h | 6 +- src/layer/x86/convolution1d_packed.h | 2702 +++++++++++++++++++ src/layer/x86/convolution1d_x86.cpp | 524 +--- src/layer/x86/convolution1d_x86.h | 2 +- src/layer/x86/convolution_packed.h | 324 +-- tests/test_convolution1d.cpp | 72 +- 12 files changed, 7289 insertions(+), 1926 deletions(-) create mode 100644 src/layer/arm/convolution1d_packed.h create mode 100644 src/layer/arm/convolution1d_packed_bf16s.h create mode 100644 src/layer/arm/convolution1d_packed_fp16s.h create mode 100644 src/layer/x86/convolution1d_packed.h diff --git a/src/layer/arm/convolution1d_arm.cpp b/src/layer/arm/convolution1d_arm.cpp index 0314f9260..ab480aec6 100644 --- a/src/layer/arm/convolution1d_arm.cpp +++ b/src/layer/arm/convolution1d_arm.cpp @@ -26,6 +26,11 @@ namespace ncnn { +#include "convolution1d_packed.h" +#if NCNN_BF16 +#include "convolution1d_packed_bf16s.h" +#endif // NCNN_BF16 + Convolution1D_arm::Convolution1D_arm() { #if __ARM_NEON @@ -61,47 +66,7 @@ int Convolution1D_arm::create_pipeline(const Option& opt) const int num_input = weight_data_size / kernel_w / num_output; - int elempack = 1; - int out_elempack = 1; - -#if __ARM_NEON - if (opt.use_packing_layout) - { - elempack = num_input % 4 == 0 ? 4 : 1; - out_elempack = num_output % 4 == 0 ? 
4 : 1; - } -#endif - - // src = kw-inch-outch - // dst = pb-pa-kw-inch/pa-outch/pb - { - Mat weight_data_r2 = weight_data.reshape(kernel_w, num_input, num_output); - - weight_data_packed.create(kernel_w, num_input / elempack, num_output / out_elempack, (size_t)4u * elempack * out_elempack, elempack * out_elempack); - - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) - { - float* g00 = weight_data_packed.channel(q / out_elempack); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < kernel_w; k++) - { - for (int i = 0; i < elempack; i++) - { - for (int j = 0; j < out_elempack; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - - g00[0] = k00[k]; - - g00++; - } - } - } - } - } - } + convolution1d_transform_kernel_packed(weight_data, weight_data_tm, num_input, num_output, kernel_w); return 0; } @@ -131,7 +96,6 @@ int Convolution1D_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Opti #endif int w = bottom_blob.w; - int h = bottom_blob.h; size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; @@ -143,7 +107,6 @@ int Convolution1D_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Opti return -100; w = bottom_blob_bordered.w; - h = bottom_blob_bordered.h; int out_elempack = 1; #if __ARM_NEON @@ -161,199 +124,7 @@ int Convolution1D_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Opti if (top_blob.empty()) return -100; -#if __ARM_NEON - if (elempack == 4 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 
0; k < kernel_w; k++) - { - float32x4_t _val = vld1q_f32(sptr); - - float32x4_t _w0 = vld1q_f32(kptr); - float32x4_t _w1 = vld1q_f32(kptr + 4); - float32x4_t _w2 = vld1q_f32(kptr + 8); - float32x4_t _w3 = vld1q_f32(kptr + 12); - -#if __aarch64__ - _sum = vmlaq_laneq_f32(_sum, _w0, _val, 0); - _sum = vmlaq_laneq_f32(_sum, _w1, _val, 1); - _sum = vmlaq_laneq_f32(_sum, _w2, _val, 2); - _sum = vmlaq_laneq_f32(_sum, _w3, _val, 3); -#else - _sum = vmlaq_lane_f32(_sum, _w0, vget_low_f32(_val), 0); - _sum = vmlaq_lane_f32(_sum, _w1, vget_low_f32(_val), 1); - _sum = vmlaq_lane_f32(_sum, _w2, vget_high_f32(_val), 0); - _sum = vmlaq_lane_f32(_sum, _w3, vget_high_f32(_val), 1); -#endif - - sptr += dilation_w * 4; - kptr += 16; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1q_f32(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 1 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = vdupq_n_f32(sptr[0]); - float32x4_t _w = vld1q_f32(kptr); - _sum = vmlaq_f32(_sum, _val, _w); - - sptr += dilation_w; - kptr += 4; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1q_f32(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const float* kptr = 
weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) // 29.23 - { - float32x4_t _val = vld1q_f32(sptr); - float32x4_t _w = vld1q_f32(kptr); - float32x4_t _s4 = vmulq_f32(_val, _w); -#if __aarch64__ - sum += vaddvq_f32(_s4); // dot -#else - float32x2_t _ss = vadd_f32(vget_low_f32(_s4), vget_high_f32(_s4)); - _ss = vpadd_f32(_ss, _ss); - sum += vget_lane_f32(_ss, 0); -#endif - - sptr += dilation_w * 4; - kptr += 4; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } -#endif // __ARM_NEON - - if (elempack == 1 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const float* kptr = (const float*)weight_data + kernel_w * h * p; - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float val = sptr[0]; - float wt = kptr[0]; - sum += val * wt; - - sptr += dilation_w; - kptr += 1; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } + convolution1d_packed(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, dilation_w, stride_w, activation_type, activation_params, opt); return 0; } @@ -460,50 +231,11 @@ int Convolution1D_arm::forward(const std::vector& bottom_blobs, std::vector } #if NCNN_BF16 -int Convolution1D_arm::create_pipeline_bf16s(const Option& opt) +int Convolution1D_arm::create_pipeline_bf16s(const Option& /*opt*/) { const int num_input = weight_data_size / kernel_w / num_output; - int elempack = 1; - int out_elempack = 1; -#if __ARM_NEON - if (opt.use_packing_layout) - { - elempack = num_input % 4 == 0 ? 
4 : 1; - out_elempack = num_output % 4 == 0 ? 4 : 1; - } -#endif - - // src = kw-inch-outch - // dst = pb-pa-kw-inch/pa-outch/pb - { - Mat weight_data_r2 = weight_data.reshape(kernel_w, num_input, num_output); - - weight_data_bf16.create(kernel_w, num_input / elempack, num_output / out_elempack, (size_t)2u * elempack * out_elempack, elempack * out_elempack); - - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) - { - unsigned short* g00 = weight_data_bf16.channel(q / out_elempack); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < kernel_w; k++) - { - for (int i = 0; i < elempack; i++) - { - for (int j = 0; j < out_elempack; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - - g00[0] = float32_to_bfloat16(k00[k]); - - g00++; - } - } - } - } - } - } + convolution1d_transform_kernel_packed_bf16s(weight_data, weight_data_tm, num_input, num_output, kernel_w); return 0; } @@ -511,7 +243,6 @@ int Convolution1D_arm::create_pipeline_bf16s(const Option& opt) int Convolution1D_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int w = bottom_blob.w; - int h = bottom_blob.h; size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; @@ -523,7 +254,6 @@ int Convolution1D_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, cons return -100; w = bottom_blob_bordered.w; - h = bottom_blob_bordered.h; int out_elempack = 1; #if __ARM_NEON @@ -541,199 +271,7 @@ int Convolution1D_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, cons if (top_blob.empty()) return -100; -#if __ARM_NEON - if (elempack == 4 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - unsigned short* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const 
unsigned short* kptr = weight_data_bf16.channel(p); - - for (int q = 0; q < h; q++) - { - const unsigned short* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = bfloat2float(vld1_u16(sptr)); - - float32x4_t _w0 = bfloat2float(vld1_u16(kptr)); - float32x4_t _w1 = bfloat2float(vld1_u16(kptr + 4)); - float32x4_t _w2 = bfloat2float(vld1_u16(kptr + 8)); - float32x4_t _w3 = bfloat2float(vld1_u16(kptr + 12)); - -#if __aarch64__ - _sum = vmlaq_laneq_f32(_sum, _w0, _val, 0); - _sum = vmlaq_laneq_f32(_sum, _w1, _val, 1); - _sum = vmlaq_laneq_f32(_sum, _w2, _val, 2); - _sum = vmlaq_laneq_f32(_sum, _w3, _val, 3); -#else - _sum = vmlaq_lane_f32(_sum, _w0, vget_low_f32(_val), 0); - _sum = vmlaq_lane_f32(_sum, _w1, vget_low_f32(_val), 1); - _sum = vmlaq_lane_f32(_sum, _w2, vget_high_f32(_val), 0); - _sum = vmlaq_lane_f32(_sum, _w3, vget_high_f32(_val), 1); -#endif - - sptr += dilation_w * 4; - kptr += 16; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_u16(outptr, float2bfloat(_sum)); - outptr += 4; - } - } - } - } - - if (elempack == 1 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - unsigned short* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const unsigned short* kptr = weight_data_bf16.channel(p); - - for (int q = 0; q < h; q++) - { - const unsigned short* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = vdupq_n_f32(bfloat16_to_float32(sptr[0])); - float32x4_t _w = bfloat2float(vld1_u16(kptr)); - _sum = vmlaq_f32(_sum, _val, _w); - - sptr += dilation_w; - kptr += 4; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_u16(outptr, float2bfloat(_sum)); - outptr 
+= 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - unsigned short* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const unsigned short* kptr = weight_data_bf16.channel(p); - - for (int q = 0; q < h; q++) - { - const unsigned short* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = bfloat2float(vld1_u16(sptr)); - float32x4_t _w = bfloat2float(vld1_u16(kptr)); - float32x4_t _s4 = vmulq_f32(_val, _w); -#if __aarch64__ - sum += vaddvq_f32(_s4); // dot -#else - float32x2_t _ss = vadd_f32(vget_low_f32(_s4), vget_high_f32(_s4)); - _ss = vpadd_f32(_ss, _ss); - sum += vget_lane_f32(_ss, 0); -#endif - - sptr += dilation_w * 4; - kptr += 4; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = float32_to_bfloat16(sum); - } - } - } - } -#endif // __ARM_NEON - - if (elempack == 1 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - unsigned short* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const unsigned short* kptr = weight_data_bf16.channel(p); - - for (int q = 0; q < h; q++) - { - const unsigned short* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float val = bfloat16_to_float32(sptr[0]); - float wt = bfloat16_to_float32(kptr[0]); - sum += val * wt; - - sptr += dilation_w; - kptr += 1; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = float32_to_bfloat16(sum); - } - } - } - } + convolution1d_packed_bf16s(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, dilation_w, stride_w, activation_type, 
activation_params, opt); return 0; } diff --git a/src/layer/arm/convolution1d_arm.h b/src/layer/arm/convolution1d_arm.h index 6d407ddc6..83e0ea838 100644 --- a/src/layer/arm/convolution1d_arm.h +++ b/src/layer/arm/convolution1d_arm.h @@ -43,17 +43,10 @@ protected: #endif public: - // pack4 - Mat weight_data_packed; + Mat weight_data_tm; // fp16 - Mat weight_data_fp16; Mat bias_data_fp16; - -#if NCNN_BF16 - // bf16 - Mat weight_data_bf16; -#endif }; } // namespace ncnn diff --git a/src/layer/arm/convolution1d_arm_asimdhp.cpp b/src/layer/arm/convolution1d_arm_asimdhp.cpp index 0225d26cf..bbbd58830 100644 --- a/src/layer/arm/convolution1d_arm_asimdhp.cpp +++ b/src/layer/arm/convolution1d_arm_asimdhp.cpp @@ -26,49 +26,13 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "convolution1d_packed_fp16s.h" + int Convolution1D_arm::create_pipeline_fp16s(const Option& opt) { const int num_input = weight_data_size / kernel_w / num_output; - int elempack = 1; - int out_elempack = 1; - - if (opt.use_packing_layout) - { - elempack = opt.use_fp16_arithmetic && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; - out_elempack = opt.use_fp16_arithmetic && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 
4 : 1; - } - - // src = kw-inch-outch - // dst = pb-pa-kw-inch/pa-outch/pb - { - Mat weight_data_r2 = weight_data.reshape(kernel_w, num_input, num_output); - - weight_data_fp16.create(kernel_w, num_input / elempack, num_output / out_elempack, (size_t)2u * elempack * out_elempack, elempack * out_elempack); - - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) - { - __fp16* g00 = weight_data_fp16.channel(q / out_elempack); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < kernel_w; k++) - { - for (int i = 0; i < elempack; i++) - { - for (int j = 0; j < out_elempack; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - - g00[0] = (__fp16)k00[k]; - - g00++; - } - } - } - } - } - } + convolution1d_transform_kernel_packed_fp16s(weight_data, weight_data_tm, num_input, num_output, kernel_w); ncnn::cast_float32_to_float16(bias_data, bias_data_fp16, opt); @@ -78,7 +42,6 @@ int Convolution1D_arm::create_pipeline_fp16s(const Option& opt) int Convolution1D_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int w = bottom_blob.w; - int h = bottom_blob.h; size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; @@ -90,7 +53,6 @@ int Convolution1D_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, cons return -100; w = bottom_blob_bordered.w; - h = bottom_blob_bordered.h; int out_elempack = (opt.use_packing_layout && num_output % 4 == 0) ? 
4 : 1; size_t out_elemsize = elemsize / elempack * out_elempack; @@ -102,185 +64,7 @@ int Convolution1D_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, cons if (top_blob.empty()) return -100; - if (elempack == 4 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = vcvt_f32_f16(vld1_f16(sptr)); - - float32x4_t _w0 = vcvt_f32_f16(vld1_f16(kptr)); - float32x4_t _w1 = vcvt_f32_f16(vld1_f16(kptr + 4)); - float32x4_t _w2 = vcvt_f32_f16(vld1_f16(kptr + 8)); - float32x4_t _w3 = vcvt_f32_f16(vld1_f16(kptr + 12)); - - _sum = vfmaq_laneq_f32(_sum, _w0, _val, 0); - _sum = vfmaq_laneq_f32(_sum, _w1, _val, 1); - _sum = vfmaq_laneq_f32(_sum, _w2, _val, 2); - _sum = vfmaq_laneq_f32(_sum, _w3, _val, 3); - - sptr += dilation_w * 4; - kptr += 16; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_f16(outptr, vcvt_f16_f32(_sum)); - outptr += 4; - } - } - } - } - - if (elempack == 1 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float32x4_t _sum = vdupq_n_f32(0.f); - - if (bias_term) - { - _sum = vld1q_f32((const float*)bias_data + p * 4); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = vcvt_f32_f16(vdup_n_f16(sptr[0])); - 
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr)); - _sum = vfmaq_f32(_sum, _val, _w); - - sptr += dilation_w; - kptr += 4; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_f16(outptr, vcvt_f16_f32(_sum)); - outptr += 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float32x4_t _val = vcvt_f32_f16(vld1_f16(sptr)); - float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr)); - float32x4_t _s4 = vmulq_f32(_val, _w); - - sum += vaddvq_f32(_s4); // dot - - sptr += dilation_w * 4; - kptr += 4; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = (__fp16)sum; - } - } - } - } - - if (elempack == 1 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float val = (float)sptr[0]; - float w = (float)kptr[0]; - sum += val * w; - - sptr += dilation_w; - kptr += 1; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = (__fp16)sum; - } - } - } - } + convolution1d_packed_fp16s(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, dilation_w, stride_w, activation_type, activation_params, opt); 
return 0; } @@ -288,7 +72,6 @@ int Convolution1D_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, cons int Convolution1D_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int w = bottom_blob.w; - int h = bottom_blob.h; size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; @@ -300,7 +83,6 @@ int Convolution1D_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con return -100; w = bottom_blob_bordered.w; - h = bottom_blob_bordered.h; int out_elempack = 1; if (opt.use_packing_layout) @@ -316,442 +98,7 @@ int Convolution1D_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con if (top_blob.empty()) return -100; - if (elempack == 8 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x8_t _sum = vdupq_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1q_f16((const __fp16*)bias_data_fp16 + p * 8); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) - { - float16x8_t _val = vld1q_f16(sptr); - - float16x8_t _w0 = vld1q_f16(kptr); - float16x8_t _w1 = vld1q_f16(kptr + 8); - float16x8_t _w2 = vld1q_f16(kptr + 16); - float16x8_t _w3 = vld1q_f16(kptr + 24); - float16x8_t _w4 = vld1q_f16(kptr + 32); - float16x8_t _w5 = vld1q_f16(kptr + 40); - float16x8_t _w6 = vld1q_f16(kptr + 48); - float16x8_t _w7 = vld1q_f16(kptr + 56); - - _sum = vfmaq_laneq_f16(_sum, _w0, _val, 0); - _sum = vfmaq_laneq_f16(_sum, _w1, _val, 1); - _sum = vfmaq_laneq_f16(_sum, _w2, _val, 2); - _sum = vfmaq_laneq_f16(_sum, _w3, _val, 3); - _sum = vfmaq_laneq_f16(_sum, _w4, _val, 4); - _sum = vfmaq_laneq_f16(_sum, _w5, _val, 5); - _sum = vfmaq_laneq_f16(_sum, _w6, _val, 6); - _sum = vfmaq_laneq_f16(_sum, _w7, _val, 
7); - - sptr += dilation_w * 8; - kptr += 64; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1q_f16(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 1 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x8_t _sum = vdupq_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1q_f16((const __fp16*)bias_data_fp16 + p * 8); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float16x8_t _val = vdupq_n_f16(sptr[0]); - float16x8_t _w = vld1q_f16(kptr); - _sum = vfmaq_f16(_sum, _val, _w); - - sptr += dilation_w; - kptr += 8; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1q_f16(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 4 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x8_t _sum = vdupq_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1q_f16((const __fp16*)bias_data_fp16 + p * 8); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float16x4_t _val = vld1_f16(sptr); - - float16x8_t _w0 = vld1q_f16(kptr); - float16x8_t _w1 = vld1q_f16(kptr + 8); - float16x8_t _w2 = vld1q_f16(kptr + 16); - float16x8_t _w3 = vld1q_f16(kptr + 24); - - _sum = vfmaq_lane_f16(_sum, _w0, _val, 0); - _sum = vfmaq_lane_f16(_sum, _w1, _val, 1); - _sum = vfmaq_lane_f16(_sum, _w2, _val, 2); - _sum = vfmaq_lane_f16(_sum, _w3, _val, 3); - - sptr += 
dilation_w * 4; - kptr += 32; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1q_f16(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 8 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = ((const __fp16*)bias_data_fp16)[p]; - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) - { - float16x8_t _val = vld1q_f16(sptr); - float16x8_t _w = vld1q_f16(kptr); - float16x8_t _s8 = vmulq_f16(_val, _w); - - float16x4_t _s4 = vadd_f16(vget_low_f16(_s8), vget_high_f16(_s8)); - sum += vaddvq_f32(vcvt_f32_f16(_s4)); // dot - - sptr += dilation_w * 8; - kptr += 8; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } - - if (elempack == 8 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x4_t _sum = vdup_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1_f16((const __fp16*)bias_data_fp16 + p * 4); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) - { - float16x8_t _val = vld1q_f16(sptr); - - float16x4_t _w0 = vld1_f16(kptr); - float16x4_t _w1 = vld1_f16(kptr + 4); - float16x4_t _w2 = vld1_f16(kptr + 8); - float16x4_t _w3 = vld1_f16(kptr + 12); - float16x4_t _w4 = vld1_f16(kptr + 16); - float16x4_t _w5 = vld1_f16(kptr + 20); - float16x4_t _w6 = vld1_f16(kptr + 24); - float16x4_t _w7 = vld1_f16(kptr + 28); - 
- _sum = vfma_laneq_f16(_sum, _w0, _val, 0); - _sum = vfma_laneq_f16(_sum, _w1, _val, 1); - _sum = vfma_laneq_f16(_sum, _w2, _val, 2); - _sum = vfma_laneq_f16(_sum, _w3, _val, 3); - _sum = vfma_laneq_f16(_sum, _w4, _val, 4); - _sum = vfma_laneq_f16(_sum, _w5, _val, 5); - _sum = vfma_laneq_f16(_sum, _w6, _val, 6); - _sum = vfma_laneq_f16(_sum, _w7, _val, 7); - - sptr += dilation_w * 8; - kptr += 32; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_f16(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x4_t _sum = vdup_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1_f16((const __fp16*)bias_data_fp16 + p * 4); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float16x4_t _val = vld1_f16(sptr); - - float16x4_t _w0 = vld1_f16(kptr); - float16x4_t _w1 = vld1_f16(kptr + 4); - float16x4_t _w2 = vld1_f16(kptr + 8); - float16x4_t _w3 = vld1_f16(kptr + 12); - - _sum = vfma_lane_f16(_sum, _w0, _val, 0); - _sum = vfma_lane_f16(_sum, _w1, _val, 1); - _sum = vfma_lane_f16(_sum, _w2, _val, 2); - _sum = vfma_lane_f16(_sum, _w3, _val, 3); - - sptr += dilation_w * 4; - kptr += 16; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_f16(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 1 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float16x4_t _sum = vdup_n_f16((__fp16)0.f); - - if (bias_term) - { - _sum = vld1_f16((const __fp16*)bias_data_fp16 + p 
* 4); - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float16x4_t _val = vdup_n_f16(sptr[0]); - float16x4_t _w = vld1_f16(kptr); - _sum = vfma_f16(_sum, _val, _w); - - sptr += dilation_w; - kptr += 4; - } - } - - _sum = activation_ps(_sum, activation_type, activation_params); - - vst1_f16(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = ((const __fp16*)bias_data_fp16)[p]; - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - float16x4_t _val = vld1_f16(sptr); - float16x4_t _w = vld1_f16(kptr); - float16x4_t _s4 = vmul_f16(_val, _w); - - sum += vaddvq_f32(vcvt_f32_f16(_s4)); // dot - - sptr += dilation_w * 4; - kptr += 4; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } - - if (elempack == 1 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - __fp16* outptr = top_blob.row<__fp16>(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const __fp16* kptr = weight_data_fp16.channel(p); - - for (int q = 0; q < h; q++) - { - const __fp16* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - float val = (float)sptr[0]; - float w = (float)kptr[0]; - sum += val * w; - - sptr += dilation_w; - kptr += 1; - } - } - - sum = activation_ss(sum, activation_type, 
activation_params); - - outptr[j] = (__fp16)sum; - } - } - } - } + convolution1d_packed_fp16sa(bottom_blob_bordered, top_blob, weight_data_tm, bias_data_fp16, kernel_w, dilation_w, stride_w, activation_type, activation_params, opt); return 0; } diff --git a/src/layer/arm/convolution1d_packed.h b/src/layer/arm/convolution1d_packed.h new file mode 100644 index 000000000..08a55f4fa --- /dev/null +++ b/src/layer/arm/convolution1d_packed.h @@ -0,0 +1,1272 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void convolution1d_transform_kernel_packed(const Mat& kernel, Mat& kernel_tm, int inh, int outh, int kernel_w) +{ + // src = kw-inh-outh + // dst = pb-pa-kw-inh/pa-outh/pb + + // clang-format off + // *INDENT-OFF* +#if __ARM_NEON +#if __aarch64__ + if (outh >= 8) + { + if (inh >= 8) + kernel_tm.create(8 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 4) + kernel_tm.create(8 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(8 * 2 * kernel_w, inh / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else + kernel_tm.create(8 * kernel_w, inh, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + } + else +#endif // __aarch64__ + if (outh >= 4) + { +#if __aarch64__ + if (inh >= 8) + kernel_tm.create(4 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(4 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(4 * 2 * kernel_w, inh / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else + kernel_tm.create(4 * kernel_w, inh, outh / 4 + (outh % 4) / 2 + outh % 2); + } + else +#endif // __ARM_NEON + if (outh >= 2) + { +#if __ARM_NEON +#if __aarch64__ + if (inh >= 8) + kernel_tm.create(2 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(2 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(2 * 2 * kernel_w, inh / 2 + inh % 2, outh / 2 + outh % 2); + else +#endif // __ARM_NEON + kernel_tm.create(2 * kernel_w, inh, outh / 2 + outh % 2); + } + else + { +#if __ARM_NEON +#if __aarch64__ + if (inh >= 8) + 
kernel_tm.create(8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh); + else if (inh >= 2) + kernel_tm.create(2 * kernel_w, inh / 2 + inh % 2, outh); + else +#endif // __ARM_NEON + kernel_tm.create(kernel_w, inh, outh); + } + // *INDENT-ON* + // clang-format on + + int q = 0; +#if __ARM_NEON +#if __aarch64__ + for (; q + 7 < outh; q += 8) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w; + const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w; + const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w; + const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w; + + float* g00 = kernel_tm.channel(q / 8); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const 
float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00 += 8; + } + } + } +#endif // __aarch64__ + for (; 
q + 3 < outh; q += 4) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + +#if __aarch64__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else + float* g00 = kernel_tm.channel(q / 4); +#endif + + int p = 0; +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const 
float* k3 = kptr3 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00 += 4; + } + } + } +#endif // __ARM_NEON + for (; q + 1 < outh; q += 2) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + +#if __aarch64__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __ARM_NEON + float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2); +#else + float* g00 = kernel_tm.channel(q / 2); +#endif + + int p = 0; +#if __ARM_NEON +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = k0[0]; + g00[1] = k0[kernel_w]; + g00[2] = k0[kernel_w * 2]; + g00[3] = k0[kernel_w * 3]; + g00[4] = k0[kernel_w * 4]; + g00[5] = k0[kernel_w * 5]; + g00[6] = k0[kernel_w * 6]; + g00[7] = k0[kernel_w * 7]; + g00[8] = k1[0]; + g00[9] = k1[kernel_w]; + g00[10] = k1[kernel_w * 2]; + g00[11] = k1[kernel_w * 3]; + g00[12] = k1[kernel_w * 4]; + g00[13] = k1[kernel_w * 5]; + g00[14] = k1[kernel_w * 6]; + g00[15] = k1[kernel_w * 7]; + g00 += 16; + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = k0[0]; + g00[1] = k0[kernel_w]; + g00[2] = k0[kernel_w * 2]; + g00[3] = k0[kernel_w * 3]; + g00[4] = k1[0]; + g00[5] = k1[kernel_w]; + g00[6] = k1[kernel_w * 2]; + g00[7] = k1[kernel_w * 3]; + g00 += 8; + } + } +#endif // __ARM_NEON + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + k0 += kernel_w; + k1 += kernel_w; + g00 += 2; + } + } + } + for (; 
p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00 += 2; + } + } + } + for (; q < outh; q++) + { + const float* kptr = (const float*)kernel + q * inh * kernel_w; + +#if __aarch64__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __ARM_NEON + float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2 + q % 2); +#else + float* g00 = kernel_tm.channel(q / 2 + q % 2); +#endif + + int p = 0; +#if __ARM_NEON +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __ARM_NEON + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00++; + } + } + } +} + +static void convolution1d_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt) +{ + const int elempack = bottom_blob.elempack; + const int inh = bottom_blob.h * elempack; + + const int N = bottom_blob.w * elempack; + + const int outw = top_blob.w; + const int out_elempack = top_blob.elempack; + const int outh = top_blob.h * out_elempack; + + const int M = top_blob.w * out_elempack; + + const float* bias_data_ptr = 
bias_data; + + int nn_outh = 0; + int remain_outh_start = 0; +#if __ARM_NEON +#if __aarch64__ + nn_outh = (outh - remain_outh_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 8; + + float* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); + float32x4_t _sum7 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + _sum1 = vld1q_f32(bias_data_ptr + p + 4); + } + + const float* kptr = weight_data_tm.channel(p / 8); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + _r1 = vld1q_f32(r0 + N); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + _r1 = vsetq_lane_f32(r0[N * 4], _r1, 0); + _r1 = vsetq_lane_f32(r0[N * 5], _r1, 1); + _r1 = vsetq_lane_f32(r0[N * 6], _r1, 2); + _r1 = vsetq_lane_f32(r0[N * 7], _r1, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 4 * 2); + float32x4_t _w3 = vld1q_f32(kptr + 4 * 3); + float32x4_t _w4 = vld1q_f32(kptr + 4 * 4); + float32x4_t _w5 = vld1q_f32(kptr + 4 * 5); + float32x4_t _w6 = vld1q_f32(kptr + 4 * 6); + float32x4_t _w7 = vld1q_f32(kptr + 4 * 7); + float32x4_t _w8 = vld1q_f32(kptr + 4 * 8); + float32x4_t _w9 = vld1q_f32(kptr + 4 * 9); + 
float32x4_t _wa = vld1q_f32(kptr + 4 * 10); + float32x4_t _wb = vld1q_f32(kptr + 4 * 11); + float32x4_t _wc = vld1q_f32(kptr + 4 * 12); + float32x4_t _wd = vld1q_f32(kptr + 4 * 13); + float32x4_t _we = vld1q_f32(kptr + 4 * 14); + float32x4_t _wf = vld1q_f32(kptr + 4 * 15); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w8, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w9, _r1, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _wa, _r1, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _wb, _r1, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _wc, _r1, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _wd, _r1, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _we, _r1, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _wf, _r1, 3); + + kptr += 64; + } + } + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 4 * 2); + float32x4_t _w3 = vld1q_f32(kptr + 4 * 3); + float32x4_t _w4 = vld1q_f32(kptr + 4 * 4); + float32x4_t _w5 = vld1q_f32(kptr + 4 * 5); + float32x4_t _w6 = vld1q_f32(kptr + 4 * 6); + float32x4_t _w7 = vld1q_f32(kptr + 4 * 7); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = 
vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + + kptr += 32; + } + } + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 8); + float32x4_t _w3 = vld1q_f32(kptr + 12); + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = vfmaq_n_f32(_sum1, _w1, val0); + _sum2 = vfmaq_n_f32(_sum2, _w2, val1); + _sum3 = vfmaq_n_f32(_sum3, _w3, val1); + + kptr += 16; + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = vdupq_n_f32(r0[0]); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + _sum0 = vfmaq_f32(_sum0, _w0, _val); + _sum1 = vfmaq_f32(_sum1, _w1, _val); + + kptr += 8; + } + } + + _sum0 = vaddq_f32(_sum0, _sum2); + _sum1 = vaddq_f32(_sum1, _sum3); + _sum4 = vaddq_f32(_sum4, _sum6); + _sum5 = vaddq_f32(_sum5, _sum7); + _sum0 = vaddq_f32(_sum0, _sum4); + _sum1 = vaddq_f32(_sum1, _sum5); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + _sum1 = activation_ps(_sum1, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1q_f32(outptr, _sum0); + vst1q_f32(outptr + M, _sum1); + outptr += 4; + } + if (out_elempack == 1) + { + outptr[0] = vgetq_lane_f32(_sum0, 0); + outptr[M] = vgetq_lane_f32(_sum0, 1); + outptr[M * 2] = vgetq_lane_f32(_sum0, 2); + outptr[M * 3] = vgetq_lane_f32(_sum0, 3); + outptr[M * 4] = vgetq_lane_f32(_sum1, 0); 
+ outptr[M * 5] = vgetq_lane_f32(_sum1, 1); + outptr[M * 6] = vgetq_lane_f32(_sum1, 2); + outptr[M * 7] = vgetq_lane_f32(_sum1, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 8; + nn_outh = (outh - remain_outh_start) / 4; +#else // __aarch64__ + nn_outh = (outh - remain_outh_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __aarch64__ + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 4; + + float* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + } + +#if __aarch64__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4); +#else + const float* kptr = weight_data_tm.channel(p / 4); +#endif + + int q = 0; +#if __aarch64__ + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + _r1 = vld1q_f32(r0 + N); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r1 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + _r1 = vsetq_lane_f32(r0[N * 4], _r1, 0); + _r1 = vsetq_lane_f32(r0[N * 5], _r1, 1); + _r1 = vsetq_lane_f32(r0[N * 6], _r1, 2); + _r1 = vsetq_lane_f32(r0[N * 7], _r1, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 8); + float32x4_t _w3 = vld1q_f32(kptr + 12); + float32x4_t _w4 = vld1q_f32(kptr + 16); + float32x4_t _w5 = vld1q_f32(kptr + 20); + float32x4_t _w6 = vld1q_f32(kptr + 24); + 
float32x4_t _w7 = vld1q_f32(kptr + 28); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w4, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w5, _r1, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w6, _r1, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w7, _r1, 3); + + kptr += 32; + } + } +#endif // __aarch64__ + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 8); + float32x4_t _w3 = vld1q_f32(kptr + 12); +#if __aarch64__ + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); +#else + _sum0 = vmlaq_lane_f32(_sum0, _w0, vget_low_f32(_r0), 0); + _sum1 = vmlaq_lane_f32(_sum1, _w1, vget_low_f32(_r0), 1); + _sum2 = vmlaq_lane_f32(_sum2, _w2, vget_high_f32(_r0), 0); + _sum3 = vmlaq_lane_f32(_sum3, _w3, vget_high_f32(_r0), 1); +#endif + + kptr += 16; + } + } + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); +#if __aarch64__ + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = 
vfmaq_n_f32(_sum1, _w1, val1); +#else + _sum0 = vmlaq_n_f32(_sum0, _w0, val0); + _sum1 = vmlaq_n_f32(_sum1, _w1, val1); +#endif + + kptr += 8; + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = vdupq_n_f32(r0[0]); + r0 += dilation_w; + } + + float32x4_t _w = vld1q_f32(kptr); +#if __aarch64__ + _sum0 = vfmaq_f32(_sum0, _val, _w); +#else + _sum0 = vmlaq_f32(_sum0, _val, _w); +#endif + + kptr += 4; + } + } + + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _sum0 = vaddq_f32(_sum0, _sum2); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1q_f32(outptr, _sum0); + outptr += 4; + } + if (out_elempack == 1) + { + outptr[0] = vgetq_lane_f32(_sum0, 0); + outptr[M] = vgetq_lane_f32(_sum0, 1); + outptr[M * 2] = vgetq_lane_f32(_sum0, 2); + outptr[M * 3] = vgetq_lane_f32(_sum0, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 4; + nn_outh = (outh - remain_outh_start) / 2; +#else // __ARM_NEON + nn_outh = (outh - remain_outh_start) / 2; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __ARM_NEON + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 2; + + float* outptr0 = top_blob.row(p); + float* outptr1 = top_blob.row(p + 1); + + for (int j = 0; j < outw; j++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; + } + +#if __aarch64__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2); +#elif __ARM_NEON + const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2); +#else + const float* kptr = weight_data_tm.channel(p / 2); +#endif + + int q = 0; +#if __ARM_NEON +#if __aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = 
vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + _r1 = vld1q_f32(r0 + N); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r1 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + _r1 = vsetq_lane_f32(r0[N * 4], _r1, 0); + _r1 = vsetq_lane_f32(r0[N * 5], _r1, 1); + _r1 = vsetq_lane_f32(r0[N * 6], _r1, 2); + _r1 = vsetq_lane_f32(r0[N * 7], _r1, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + float32x4_t _w2 = vld1q_f32(kptr + 8); + float32x4_t _w3 = vld1q_f32(kptr + 12); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + _sum2 = vfmaq_f32(_sum2, _r0, _w2); + _sum3 = vfmaq_f32(_sum3, _r1, _w3); + + kptr += 16; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + sum0 += vaddvq_f32(_sum0); + sum1 += vaddvq_f32(_sum2); + _sum0 = vdupq_n_f32(0.f); + _sum1 = vdupq_n_f32(0.f); +#else // __aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); +#endif // __aarch64__ + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); 
+#if __aarch64__ + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r0, _w1); +#else + _sum0 = vmlaq_f32(_sum0, _r0, _w0); + _sum1 = vmlaq_f32(_sum1, _r0, _w1); +#endif + + kptr += 8; + } + } +#if __aarch64__ + sum0 += vaddvq_f32(_sum0); + sum1 += vaddvq_f32(_sum1); +#else + float32x2_t _ss0 = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0)); + float32x2_t _ss1 = vadd_f32(vget_low_f32(_sum1), vget_high_f32(_sum1)); + float32x2_t _ss = vpadd_f32(_ss0, _ss1); + sum0 += vget_lane_f32(_ss, 0); + sum1 += vget_lane_f32(_ss, 1); +#endif +#endif // __ARM_NEON + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + sum0 += val0 * kptr[0]; + sum1 += val0 * kptr[1]; + sum0 += val1 * kptr[2]; + sum1 += val1 * kptr[3]; + + kptr += 4; + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = r0[0]; + r0 += dilation_w; + } + + sum0 += val * kptr[0]; + sum1 += val * kptr[1]; + + kptr += 2; + } + } + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = sum0; + outptr1[0] = sum1; + outptr0 += 1; + outptr1 += 1; + } + } + remain_outh_start += nn_outh * 2; + for (int p = remain_outh_start; p < outh; p++) + { + float* outptr = top_blob.row(p); + + for (int j = 0; j < outw; j++) + { + float sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + +#if __aarch64__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __ARM_NEON + const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2 + p % 2); +#else + const float* kptr = weight_data_tm.channel(p / 2 + p % 2); +#endif + + int q = 0; +#if __ARM_NEON +#if 
__aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + _r1 = vld1q_f32(r0 + N); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r1 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + _r1 = vsetq_lane_f32(r0[N * 4], _r1, 0); + _r1 = vsetq_lane_f32(r0[N * 5], _r1, 1); + _r1 = vsetq_lane_f32(r0[N * 6], _r1, 2); + _r1 = vsetq_lane_f32(r0[N * 7], _r1, 3); + r0 += dilation_w; + } + + float32x4_t _w0 = vld1q_f32(kptr); + float32x4_t _w1 = vld1q_f32(kptr + 4); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + + kptr += 8; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + sum += vaddvq_f32(_sum0); +#endif // __aarch64__ + float32x4_t _sum = vdupq_n_f32(0.f); + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vld1q_f32(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float32x4_t(); + _r0 = vsetq_lane_f32(r0[0], _r0, 0); + _r0 = vsetq_lane_f32(r0[N], _r0, 1); + _r0 = vsetq_lane_f32(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f32(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float32x4_t _w = vld1q_f32(kptr); +#if __aarch64__ + _sum = vfmaq_f32(_sum, _r0, _w); +#else + _sum = vmlaq_f32(_sum, _r0, _w); +#endif + + kptr += 4; + } + } +#if __aarch64__ + sum += vaddvq_f32(_sum); +#else + float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum)); + _ss = vpadd_f32(_ss, _ss); + sum += vget_lane_f32(_ss, 0); +#endif +#endif // __ARM_NEON + for (; q + 1 < 
inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + sum += val0 * kptr[0]; + sum += val1 * kptr[1]; + + kptr += 2; + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = r0[0]; + r0 += dilation_w; + } + + sum += val * kptr[0]; + + kptr += 1; + } + } + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = sum; + outptr += 1; + } + } +} diff --git a/src/layer/arm/convolution1d_packed_bf16s.h b/src/layer/arm/convolution1d_packed_bf16s.h new file mode 100644 index 000000000..4494cc4f2 --- /dev/null +++ b/src/layer/arm/convolution1d_packed_bf16s.h @@ -0,0 +1,1311 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void convolution1d_transform_kernel_packed_bf16s(const Mat& kernel, Mat& kernel_tm, int inh, int outh, int kernel_w) +{ + // src = kw-inh-outh + // dst = pb-pa-kw-inh/pa-outh/pb + + // clang-format off + // *INDENT-OFF* +#if __ARM_NEON +#if __aarch64__ + if (outh >= 8) + { + if (inh >= 8) + kernel_tm.create(8 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 4) + kernel_tm.create(8 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(8 * 2 * kernel_w, inh / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else + kernel_tm.create(8 * kernel_w, inh, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + } + else +#endif // __aarch64__ + if (outh >= 4) + { +#if __aarch64__ + if (inh >= 8) + kernel_tm.create(4 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(4 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(4 * 2 * kernel_w, inh / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else + kernel_tm.create(4 * kernel_w, inh, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + } + else +#endif // __ARM_NEON + if (outh >= 2) + { +#if __ARM_NEON +#if __aarch64__ + if (inh >= 8) + kernel_tm.create(2 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(2 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(2 * 2 * kernel_w, inh / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else +#endif 
// __ARM_NEON + kernel_tm.create(2 * kernel_w, inh, outh / 2 + outh % 2, (size_t)2u); + } + else + { +#if __ARM_NEON +#if __aarch64__ + if (inh >= 8) + kernel_tm.create(8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh, (size_t)2u); + else +#endif // __aarch64__ + if (inh >= 4) + kernel_tm.create(4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(2 * kernel_w, inh / 2 + inh % 2, outh, (size_t)2u); + else +#endif // __ARM_NEON + kernel_tm.create(kernel_w, inh, outh, (size_t)2u); + } + // *INDENT-ON* + // clang-format on + + int q = 0; +#if __ARM_NEON +#if __aarch64__ + for (; q + 7 < outh; q += 8) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w; + const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w; + const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w; + const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w; + + unsigned short* g00 = kernel_tm.channel(q / 8); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + g00[4] = float32_to_bfloat16(k4[k]); + g00[5] = 
float32_to_bfloat16(k5[k]); + g00[6] = float32_to_bfloat16(k6[k]); + g00[7] = float32_to_bfloat16(k7[k]); + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + g00[4] = float32_to_bfloat16(k4[k]); + g00[5] = float32_to_bfloat16(k5[k]); + g00[6] = float32_to_bfloat16(k6[k]); + g00[7] = float32_to_bfloat16(k7[k]); + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + g00[4] = float32_to_bfloat16(k4[k]); + g00[5] = float32_to_bfloat16(k5[k]); + g00[6] = float32_to_bfloat16(k6[k]); + g00[7] = float32_to_bfloat16(k7[k]); + k0 += kernel_w; + k1 += kernel_w; + 
k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + g00[4] = float32_to_bfloat16(k4[k]); + g00[5] = float32_to_bfloat16(k5[k]); + g00[6] = float32_to_bfloat16(k6[k]); + g00[7] = float32_to_bfloat16(k7[k]); + g00 += 8; + } + } + } +#endif // __aarch64__ + for (; q + 3 < outh; q += 4) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + +#if __aarch64__ + unsigned short* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else + unsigned short* g00 = kernel_tm.channel(q / 4); +#endif + + int p = 0; +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k 
< kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00[2] = float32_to_bfloat16(k2[k]); + g00[3] = float32_to_bfloat16(k3[k]); + g00 += 4; + } + } + } +#endif // __ARM_NEON + for (; q + 1 < outh; q += 2) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + +#if __aarch64__ + unsigned short* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __ARM_NEON + unsigned short* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2); +#else + unsigned short* g00 = kernel_tm.channel(q / 2); +#endif + + int p = 0; +#if __ARM_NEON +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const 
float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = float32_to_bfloat16(k0[0]); + g00[1] = float32_to_bfloat16(k0[kernel_w]); + g00[2] = float32_to_bfloat16(k0[kernel_w * 2]); + g00[3] = float32_to_bfloat16(k0[kernel_w * 3]); + g00[4] = float32_to_bfloat16(k0[kernel_w * 4]); + g00[5] = float32_to_bfloat16(k0[kernel_w * 5]); + g00[6] = float32_to_bfloat16(k0[kernel_w * 6]); + g00[7] = float32_to_bfloat16(k0[kernel_w * 7]); + g00[8] = float32_to_bfloat16(k1[0]); + g00[9] = float32_to_bfloat16(k1[kernel_w]); + g00[10] = float32_to_bfloat16(k1[kernel_w * 2]); + g00[11] = float32_to_bfloat16(k1[kernel_w * 3]); + g00[12] = float32_to_bfloat16(k1[kernel_w * 4]); + g00[13] = float32_to_bfloat16(k1[kernel_w * 5]); + g00[14] = float32_to_bfloat16(k1[kernel_w * 6]); + g00[15] = float32_to_bfloat16(k1[kernel_w * 7]); + g00 += 16; + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = float32_to_bfloat16(k0[0]); + g00[1] = float32_to_bfloat16(k0[kernel_w]); + g00[2] = float32_to_bfloat16(k0[kernel_w * 2]); + g00[3] = float32_to_bfloat16(k0[kernel_w * 3]); + g00[4] = float32_to_bfloat16(k1[0]); + g00[5] = float32_to_bfloat16(k1[kernel_w]); + g00[6] = float32_to_bfloat16(k1[kernel_w * 2]); + g00[7] = float32_to_bfloat16(k1[kernel_w * 3]); + g00 += 8; + } + } +#endif // __ARM_NEON + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + k0 += kernel_w; + k1 += kernel_w; + g00 += 2; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = 
float32_to_bfloat16(k0[k]); + g00[1] = float32_to_bfloat16(k1[k]); + g00 += 2; + } + } + } + for (; q < outh; q++) + { + const float* kptr = (const float*)kernel + q * inh * kernel_w; + +#if __aarch64__ + unsigned short* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __ARM_NEON + unsigned short* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2 + q % 2); +#else + unsigned short* g00 = kernel_tm.channel(q / 2 + q % 2); +#endif + + int p = 0; +#if __ARM_NEON +#if __aarch64__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __aarch64__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __ARM_NEON + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = float32_to_bfloat16(k0[k]); + k0 += kernel_w; + g00 += 1; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = float32_to_bfloat16(k0[k]); + g00++; + } + } + } +} + +static void convolution1d_packed_bf16s(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt) +{ + const int elempack = bottom_blob.elempack; + const int inh = bottom_blob.h * elempack; + + const int N = bottom_blob.w * elempack; + + const int outw = top_blob.w; + const int out_elempack = top_blob.elempack; + const int outh = top_blob.h * out_elempack; + + const int M = top_blob.w * out_elempack; + + const float* bias_data_ptr 
= bias_data; + + int nn_outh = 0; + int remain_outh_start = 0; +#if __ARM_NEON +#if __aarch64__ + nn_outh = (outh - remain_outh_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 8; + + unsigned short* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); + float32x4_t _sum7 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + _sum1 = vld1q_f32(bias_data_ptr + p + 4); + } + + const unsigned short* kptr = weight_data_tm.channel(p / 8); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + _r1 = bfloat2float(vld1_u16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x8_t _r_u16 = uint16x8_t(); + _r_u16 = vsetq_lane_u16(r0[0], _r_u16, 0); + _r_u16 = vsetq_lane_u16(r0[N], _r_u16, 1); + _r_u16 = vsetq_lane_u16(r0[N * 2], _r_u16, 2); + _r_u16 = vsetq_lane_u16(r0[N * 3], _r_u16, 3); + _r_u16 = vsetq_lane_u16(r0[N * 4], _r_u16, 4); + _r_u16 = vsetq_lane_u16(r0[N * 5], _r_u16, 5); + _r_u16 = vsetq_lane_u16(r0[N * 6], _r_u16, 6); + _r_u16 = vsetq_lane_u16(r0[N * 7], _r_u16, 7); + _r0 = bfloat2float(vget_low_u16(_r_u16)); + _r1 = bfloat2float(vget_high_u16(_r_u16)); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + uint16x8_t _w45 = vld1q_u16(kptr + 16); + uint16x8_t _w67 = vld1q_u16(kptr + 24); + uint16x8_t _w89 = vld1q_u16(kptr + 32); + 
uint16x8_t _wab = vld1q_u16(kptr + 40); + uint16x8_t _wcd = vld1q_u16(kptr + 48); + uint16x8_t _wef = vld1q_u16(kptr + 56); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); + float32x4_t _w4 = bfloat2float(vget_low_u16(_w45)); + float32x4_t _w5 = bfloat2float(vget_high_u16(_w45)); + float32x4_t _w6 = bfloat2float(vget_low_u16(_w67)); + float32x4_t _w7 = bfloat2float(vget_high_u16(_w67)); + float32x4_t _w8 = bfloat2float(vget_low_u16(_w89)); + float32x4_t _w9 = bfloat2float(vget_high_u16(_w89)); + float32x4_t _wa = bfloat2float(vget_low_u16(_wab)); + float32x4_t _wb = bfloat2float(vget_high_u16(_wab)); + float32x4_t _wc = bfloat2float(vget_low_u16(_wcd)); + float32x4_t _wd = bfloat2float(vget_high_u16(_wcd)); + float32x4_t _we = bfloat2float(vget_low_u16(_wef)); + float32x4_t _wf = bfloat2float(vget_high_u16(_wef)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w8, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w9, _r1, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _wa, _r1, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _wb, _r1, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _wc, _r1, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _wd, _r1, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _we, _r1, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _wf, _r1, 3); + + kptr += 64; + } + } + for (; q + 3 < inh; q += 4) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = 
bfloat2float(vld1_u16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x4_t _r_u16 = uint16x4_t(); + _r_u16 = vset_lane_u16(r0[0], _r_u16, 0); + _r_u16 = vset_lane_u16(r0[N], _r_u16, 1); + _r_u16 = vset_lane_u16(r0[N * 2], _r_u16, 2); + _r_u16 = vset_lane_u16(r0[N * 3], _r_u16, 3); + _r0 = bfloat2float(_r_u16); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + uint16x8_t _w45 = vld1q_u16(kptr + 16); + uint16x8_t _w67 = vld1q_u16(kptr + 24); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); + float32x4_t _w4 = bfloat2float(vget_low_u16(_w45)); + float32x4_t _w5 = bfloat2float(vget_high_u16(_w45)); + float32x4_t _w6 = bfloat2float(vget_low_u16(_w67)); + float32x4_t _w7 = bfloat2float(vget_high_u16(_w67)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + + kptr += 32; + } + } + for (; q + 1 < inh; q += 2) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = bfloat16_to_float32(r0[0]); + val1 = bfloat16_to_float32(r0[N]); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = 
vfmaq_n_f32(_sum1, _w1, val0); + _sum2 = vfmaq_n_f32(_sum2, _w2, val1); + _sum3 = vfmaq_n_f32(_sum3, _w3, val1); + + kptr += 16; + } + } + for (; q < inh; q++) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = bfloat2float(vdup_n_u16(r0[0])); + r0 += dilation_w; + } + + uint16x8_t _w = vld1q_u16(kptr); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w)); + _sum0 = vfmaq_f32(_sum0, _w0, _val); + _sum1 = vfmaq_f32(_sum1, _w1, _val); + + kptr += 8; + } + } + + _sum0 = vaddq_f32(_sum0, _sum2); + _sum1 = vaddq_f32(_sum1, _sum3); + _sum4 = vaddq_f32(_sum4, _sum6); + _sum5 = vaddq_f32(_sum5, _sum7); + _sum0 = vaddq_f32(_sum0, _sum4); + _sum1 = vaddq_f32(_sum1, _sum5); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + _sum1 = activation_ps(_sum1, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1_u16(outptr, float2bfloat(_sum0)); + vst1_u16(outptr + M, float2bfloat(_sum1)); + outptr += 4; + } + if (out_elempack == 1) + { + uint16x4_t _sum0_u16 = float2bfloat(_sum0); + uint16x4_t _sum1_u16 = float2bfloat(_sum1); + outptr[0] = vget_lane_u16(_sum0_u16, 0); + outptr[M] = vget_lane_u16(_sum0_u16, 1); + outptr[M * 2] = vget_lane_u16(_sum0_u16, 2); + outptr[M * 3] = vget_lane_u16(_sum0_u16, 3); + outptr[M * 4] = vget_lane_u16(_sum1_u16, 0); + outptr[M * 5] = vget_lane_u16(_sum1_u16, 1); + outptr[M * 6] = vget_lane_u16(_sum1_u16, 2); + outptr[M * 7] = vget_lane_u16(_sum1_u16, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 8; + nn_outh = (outh - remain_outh_start) / 4; +#else // __aarch64__ + nn_outh = (outh - remain_outh_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __aarch64__ + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 4; + + unsigned short* outptr = top_blob.row(p / 
out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + } + +#if __aarch64__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 4); +#endif + + int q = 0; +#if __aarch64__ + for (; q + 7 < inh; q += 8) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + _r1 = bfloat2float(vld1_u16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x8_t _r_u16 = uint16x8_t(); + _r_u16 = vsetq_lane_u16(r0[0], _r_u16, 0); + _r_u16 = vsetq_lane_u16(r0[N], _r_u16, 1); + _r_u16 = vsetq_lane_u16(r0[N * 2], _r_u16, 2); + _r_u16 = vsetq_lane_u16(r0[N * 3], _r_u16, 3); + _r_u16 = vsetq_lane_u16(r0[N * 4], _r_u16, 4); + _r_u16 = vsetq_lane_u16(r0[N * 5], _r_u16, 5); + _r_u16 = vsetq_lane_u16(r0[N * 6], _r_u16, 6); + _r_u16 = vsetq_lane_u16(r0[N * 7], _r_u16, 7); + _r0 = bfloat2float(vget_low_u16(_r_u16)); + _r1 = bfloat2float(vget_high_u16(_r_u16)); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + uint16x8_t _w45 = vld1q_u16(kptr + 16); + uint16x8_t _w67 = vld1q_u16(kptr + 24); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); + float32x4_t _w4 = bfloat2float(vget_low_u16(_w45)); + float32x4_t _w5 = bfloat2float(vget_high_u16(_w45)); + float32x4_t _w6 = bfloat2float(vget_low_u16(_w67)); + float32x4_t _w7 = bfloat2float(vget_high_u16(_w67)); + _sum0 = 
vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w4, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w5, _r1, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w6, _r1, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w7, _r1, 3); + + kptr += 32; + } + } +#endif // __aarch64__ + for (; q + 3 < inh; q += 4) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x4_t _r_u16 = uint16x4_t(); + _r_u16 = vset_lane_u16(r0[0], _r_u16, 0); + _r_u16 = vset_lane_u16(r0[N], _r_u16, 1); + _r_u16 = vset_lane_u16(r0[N * 2], _r_u16, 2); + _r_u16 = vset_lane_u16(r0[N * 3], _r_u16, 3); + _r0 = bfloat2float(_r_u16); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); +#if __aarch64__ + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); +#else + _sum0 = vmlaq_lane_f32(_sum0, _w0, vget_low_f32(_r0), 0); + _sum1 = vmlaq_lane_f32(_sum1, _w1, vget_low_f32(_r0), 1); + _sum2 = vmlaq_lane_f32(_sum2, _w2, vget_high_f32(_r0), 0); + _sum3 = vmlaq_lane_f32(_sum3, _w3, vget_high_f32(_r0), 1); +#endif + + kptr += 16; + } + } + for (; q + 1 < inh; q += 2) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = 
bfloat16_to_float32(r0[0]); + val1 = bfloat16_to_float32(r0[N]); + r0 += dilation_w; + } + + uint16x8_t _w = vld1q_u16(kptr); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w)); +#if __aarch64__ + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = vfmaq_n_f32(_sum1, _w1, val1); +#else + _sum0 = vmlaq_n_f32(_sum0, _w0, val0); + _sum1 = vmlaq_n_f32(_sum1, _w1, val1); +#endif + + kptr += 8; + } + } + for (; q < inh; q++) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = bfloat2float(vdup_n_u16(r0[0])); + r0 += dilation_w; + } + + float32x4_t _w = bfloat2float(vld1_u16(kptr)); +#if __aarch64__ + _sum0 = vfmaq_f32(_sum0, _val, _w); +#else + _sum0 = vmlaq_f32(_sum0, _val, _w); +#endif + + kptr += 4; + } + } + + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _sum0 = vaddq_f32(_sum0, _sum2); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1_u16(outptr, float2bfloat(_sum0)); + outptr += 4; + } + if (out_elempack == 1) + { + uint16x4_t _sum0_u16 = float2bfloat(_sum0); + outptr[0] = vget_lane_u16(_sum0_u16, 0); + outptr[M] = vget_lane_u16(_sum0_u16, 1); + outptr[M * 2] = vget_lane_u16(_sum0_u16, 2); + outptr[M * 3] = vget_lane_u16(_sum0_u16, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 4; + nn_outh = (outh - remain_outh_start) / 2; +#else // __ARM_NEON + nn_outh = (outh - remain_outh_start) / 2; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __ARM_NEON + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 2; + + unsigned short* outptr0 = top_blob.row(p); + unsigned short* outptr1 = top_blob.row(p + 1); + + for (int j = 0; j < outw; j++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; 
+ } + +#if __aarch64__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2); +#elif __ARM_NEON + const unsigned short* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 2); +#endif + + int q = 0; +#if __ARM_NEON +#if __aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + _r1 = bfloat2float(vld1_u16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x8_t _r01_u16 = uint16x8_t(); + _r01_u16 = vsetq_lane_u16(r0[0], _r01_u16, 0); + _r01_u16 = vsetq_lane_u16(r0[N], _r01_u16, 1); + _r01_u16 = vsetq_lane_u16(r0[N * 2], _r01_u16, 2); + _r01_u16 = vsetq_lane_u16(r0[N * 3], _r01_u16, 3); + _r01_u16 = vsetq_lane_u16(r0[N * 4], _r01_u16, 4); + _r01_u16 = vsetq_lane_u16(r0[N * 5], _r01_u16, 5); + _r01_u16 = vsetq_lane_u16(r0[N * 6], _r01_u16, 6); + _r01_u16 = vsetq_lane_u16(r0[N * 7], _r01_u16, 7); + _r0 = bfloat2float(vget_low_u16(_r01_u16)); + _r1 = bfloat2float(vget_high_u16(_r01_u16)); + r0 += dilation_w; + } + + uint16x8_t _w01 = vld1q_u16(kptr); + uint16x8_t _w23 = vld1q_u16(kptr + 8); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w01)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w01)); + float32x4_t _w2 = bfloat2float(vget_low_u16(_w23)); + float32x4_t _w3 = bfloat2float(vget_high_u16(_w23)); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + _sum2 = vfmaq_f32(_sum2, _r0, _w2); + _sum3 = vfmaq_f32(_sum3, _r1, _w3); + + kptr += 16; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + sum0 += vaddvq_f32(_sum0); + sum1 
+= vaddvq_f32(_sum2); + _sum0 = vdupq_n_f32(0.f); + _sum1 = vdupq_n_f32(0.f); +#else // __aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); +#endif // __aarch64__ + for (; q + 3 < inh; q += 4) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x4_t _r0_u16 = uint16x4_t(); + _r0_u16 = vset_lane_u16(r0[0], _r0_u16, 0); + _r0_u16 = vset_lane_u16(r0[N], _r0_u16, 1); + _r0_u16 = vset_lane_u16(r0[N * 2], _r0_u16, 2); + _r0_u16 = vset_lane_u16(r0[N * 3], _r0_u16, 3); + _r0 = bfloat2float(_r0_u16); + r0 += dilation_w; + } + + uint16x8_t _w = vld1q_u16(kptr); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w)); +#if __aarch64__ + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r0, _w1); +#else + _sum0 = vmlaq_f32(_sum0, _r0, _w0); + _sum1 = vmlaq_f32(_sum1, _r0, _w1); +#endif + + kptr += 8; + } + } +#if __aarch64__ + sum0 += vaddvq_f32(_sum0); + sum1 += vaddvq_f32(_sum1); +#else + float32x2_t _ss0 = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0)); + float32x2_t _ss1 = vadd_f32(vget_low_f32(_sum1), vget_high_f32(_sum1)); + float32x2_t _ss = vpadd_f32(_ss0, _ss1); + sum0 += vget_lane_f32(_ss, 0); + sum1 += vget_lane_f32(_ss, 1); +#endif +#endif // __ARM_NEON + for (; q + 1 < inh; q += 2) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = bfloat16_to_float32(r0[0]); + val1 = bfloat16_to_float32(r0[N]); + r0 += dilation_w; + } + + sum0 += val0 * bfloat16_to_float32(kptr[0]); + sum1 += val0 * bfloat16_to_float32(kptr[1]); + sum0 += val1 * bfloat16_to_float32(kptr[2]); + sum1 += val1 * bfloat16_to_float32(kptr[3]); + + kptr += 
4; + } + } + for (; q < inh; q++) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = bfloat16_to_float32(r0[0]); + r0 += dilation_w; + } + + sum0 += val * bfloat16_to_float32(kptr[0]); + sum1 += val * bfloat16_to_float32(kptr[1]); + + kptr += 2; + } + } + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = float32_to_bfloat16(sum0); + outptr1[0] = float32_to_bfloat16(sum1); + outptr0 += 1; + outptr1 += 1; + } + } + remain_outh_start += nn_outh * 2; + for (int p = remain_outh_start; p < outh; p++) + { + unsigned short* outptr = top_blob.row(p); + + for (int j = 0; j < outw; j++) + { + float sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + +#if __aarch64__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __ARM_NEON + const unsigned short* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2 + p % 2); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 2 + p % 2); +#endif + + int q = 0; +#if __ARM_NEON +#if __aarch64__ + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + _r1 = bfloat2float(vld1_u16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x8_t _r01_u16 = uint16x8_t(); + _r01_u16 = vsetq_lane_u16(r0[0], _r01_u16, 0); + _r01_u16 = vsetq_lane_u16(r0[N], _r01_u16, 1); + _r01_u16 = vsetq_lane_u16(r0[N * 2], _r01_u16, 2); + _r01_u16 = vsetq_lane_u16(r0[N * 3], _r01_u16, 3); + _r01_u16 = vsetq_lane_u16(r0[N * 4], _r01_u16, 4); + _r01_u16 = vsetq_lane_u16(r0[N * 5], 
_r01_u16, 5); + _r01_u16 = vsetq_lane_u16(r0[N * 6], _r01_u16, 6); + _r01_u16 = vsetq_lane_u16(r0[N * 7], _r01_u16, 7); + _r0 = bfloat2float(vget_low_u16(_r01_u16)); + _r1 = bfloat2float(vget_high_u16(_r01_u16)); + r0 += dilation_w; + } + + uint16x8_t _w = vld1q_u16(kptr); + float32x4_t _w0 = bfloat2float(vget_low_u16(_w)); + float32x4_t _w1 = bfloat2float(vget_high_u16(_w)); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + + kptr += 8; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + sum += vaddvq_f32(_sum0); +#endif // __aarch64__ + float32x4_t _sum = vdupq_n_f32(0.f); + for (; q + 3 < inh; q += 4) + { + const unsigned short* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = bfloat2float(vld1_u16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + uint16x4_t _r0_u16 = uint16x4_t(); + _r0_u16 = vset_lane_u16(r0[0], _r0_u16, 0); + _r0_u16 = vset_lane_u16(r0[N], _r0_u16, 1); + _r0_u16 = vset_lane_u16(r0[N * 2], _r0_u16, 2); + _r0_u16 = vset_lane_u16(r0[N * 3], _r0_u16, 3); + _r0 = bfloat2float(_r0_u16); + r0 += dilation_w; + } + + float32x4_t _w = bfloat2float(vld1_u16(kptr)); +#if __aarch64__ + _sum = vfmaq_f32(_sum, _r0, _w); +#else + _sum = vmlaq_f32(_sum, _r0, _w); +#endif + + kptr += 4; + } + } +#if __aarch64__ + sum += vaddvq_f32(_sum); +#else + float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum)); + _ss = vpadd_f32(_ss, _ss); + sum += vget_lane_f32(_ss, 0); +#endif +#endif // __ARM_NEON + for (; q + 1 < inh; q += 2) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = bfloat16_to_float32(r0[0]); + val1 = bfloat16_to_float32(r0[N]); + r0 += dilation_w; + } + + sum += val0 * bfloat16_to_float32(kptr[0]); + sum += val1 * bfloat16_to_float32(kptr[1]); + + kptr += 2; + } + } + for (; q < 
inh; q++) + { + const unsigned short* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = bfloat16_to_float32(r0[0]); + r0 += dilation_w; + } + + sum += val * bfloat16_to_float32(kptr[0]); + + kptr += 1; + } + } + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = float32_to_bfloat16(sum); + outptr += 1; + } + } +} diff --git a/src/layer/arm/convolution1d_packed_fp16s.h b/src/layer/arm/convolution1d_packed_fp16s.h new file mode 100644 index 000000000..680761575 --- /dev/null +++ b/src/layer/arm/convolution1d_packed_fp16s.h @@ -0,0 +1,1848 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void convolution1d_transform_kernel_packed_fp16s(const Mat& kernel, Mat& kernel_tm, int inh, int outh, int kernel_w) +{ + // src = kw-inh-outh + // dst = pb-pa-kw-inh/pa-outh/pb + + // clang-format off + // *INDENT-OFF* + if (outh >= 8) + { + if (inh >= 8) + kernel_tm.create(8 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 4) + kernel_tm.create(8 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(8 * 2 * kernel_w, inh / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else + kernel_tm.create(8 * kernel_w, inh, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + } + else if (outh >= 4) + { + if (inh >= 8) + kernel_tm.create(4 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 4) + kernel_tm.create(4 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(4 * 2 * kernel_w, inh / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + else + kernel_tm.create(4 * kernel_w, inh, outh / 4 + (outh % 4) / 2 + outh % 2, (size_t)2u); + } + else if (outh >= 2) + { + if (inh >= 8) + kernel_tm.create(2 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else if (inh >= 4) + kernel_tm.create(2 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(2 * 2 * kernel_w, inh / 2 + inh % 2, outh / 2 + outh % 2, (size_t)2u); + else + kernel_tm.create(2 * kernel_w, inh, outh / 2 + outh % 2, (size_t)2u); + } + else + { + if (inh >= 8) + kernel_tm.create(8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, 
outh, (size_t)2u); + else if (inh >= 4) + kernel_tm.create(4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh, (size_t)2u); + else if (inh >= 2) + kernel_tm.create(2 * kernel_w, inh / 2 + inh % 2, outh, (size_t)2u); + else + kernel_tm.create(kernel_w, inh, outh, (size_t)2u); + } + // *INDENT-ON* + // clang-format on + + int q = 0; + for (; q + 7 < outh; q += 8) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w; + const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w; + const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w; + const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w; + + __fp16* g00 = kernel_tm.channel(q / 8); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + g00[4] = (__fp16)k4[k]; + g00[5] = (__fp16)k5[k]; + g00[6] = (__fp16)k6[k]; + g00[7] = (__fp16)k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 
= kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + g00[4] = (__fp16)k4[k]; + g00[5] = (__fp16)k5[k]; + g00[6] = (__fp16)k6[k]; + g00[7] = (__fp16)k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + g00[4] = (__fp16)k4[k]; + g00[5] = (__fp16)k5[k]; + g00[6] = (__fp16)k6[k]; + g00[7] = (__fp16)k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + g00[4] = 
(__fp16)k4[k]; + g00[5] = (__fp16)k5[k]; + g00[6] = (__fp16)k6[k]; + g00[7] = (__fp16)k7[k]; + g00 += 8; + } + } + } + for (; q + 3 < outh; q += 4) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + + __fp16* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + 
p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00[2] = (__fp16)k2[k]; + g00[3] = (__fp16)k3[k]; + g00 += 4; + } + } + } + for (; q + 1 < outh; q += 2) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + + __fp16* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = (__fp16)k0[0]; + g00[1] = (__fp16)k0[kernel_w]; + g00[2] = (__fp16)k0[kernel_w * 2]; + g00[3] = (__fp16)k0[kernel_w * 3]; + g00[4] = (__fp16)k0[kernel_w * 4]; + g00[5] = (__fp16)k0[kernel_w * 5]; + g00[6] = (__fp16)k0[kernel_w * 6]; + g00[7] = (__fp16)k0[kernel_w * 7]; + g00[8] = (__fp16)k1[0]; + g00[9] = (__fp16)k1[kernel_w]; + g00[10] = (__fp16)k1[kernel_w * 2]; + g00[11] = (__fp16)k1[kernel_w * 3]; + g00[12] = (__fp16)k1[kernel_w * 4]; + g00[13] = (__fp16)k1[kernel_w * 5]; + g00[14] = (__fp16)k1[kernel_w * 6]; + g00[15] = (__fp16)k1[kernel_w * 7]; + g00 += 16; + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = (__fp16)k0[0]; + g00[1] = (__fp16)k0[kernel_w]; + g00[2] = (__fp16)k0[kernel_w * 2]; + g00[3] = (__fp16)k0[kernel_w * 3]; + g00[4] = (__fp16)k1[0]; + g00[5] = (__fp16)k1[kernel_w]; + g00[6] = (__fp16)k1[kernel_w * 2]; + g00[7] = (__fp16)k1[kernel_w * 3]; + g00 += 8; + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = 
(__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + k0 += kernel_w; + k1 += kernel_w; + g00 += 2; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = (__fp16)k0[k]; + g00[1] = (__fp16)k1[k]; + g00 += 2; + } + } + } + for (; q < outh; q++) + { + const float* kptr = (const float*)kernel + q * inh * kernel_w; + + __fp16* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); + + int p = 0; + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = (__fp16)k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = (__fp16)k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = (__fp16)k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = (__fp16)k0[k]; + g00++; + } + } + } +} + +static void convolution1d_packed_fp16s(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt) +{ + const int elempack = bottom_blob.elempack; + const int inh = bottom_blob.h * elempack; + + const int N = bottom_blob.w * elempack; + + const int outw = top_blob.w; + const int out_elempack = top_blob.elempack; + const int outh = top_blob.h * out_elempack; + + const int M = top_blob.w * out_elempack; + + const float* bias_data_ptr = bias_data; + + int nn_outh = 0; + int remain_outh_start = 0; + nn_outh = 
(outh - remain_outh_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 8; + + __fp16* outptr = top_blob.row<__fp16>(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); + float32x4_t _sum7 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + _sum1 = vld1q_f32(bias_data_ptr + p + 4); + } + + const __fp16* kptr = weight_data_tm.channel(p / 8); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + _r1 = vcvt_f32_f16(vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x8_t _r_f16 = float16x8_t(); + _r_f16 = vsetq_lane_f16(r0[0], _r_f16, 0); + _r_f16 = vsetq_lane_f16(r0[N], _r_f16, 1); + _r_f16 = vsetq_lane_f16(r0[N * 2], _r_f16, 2); + _r_f16 = vsetq_lane_f16(r0[N * 3], _r_f16, 3); + _r_f16 = vsetq_lane_f16(r0[N * 4], _r_f16, 4); + _r_f16 = vsetq_lane_f16(r0[N * 5], _r_f16, 5); + _r_f16 = vsetq_lane_f16(r0[N * 6], _r_f16, 6); + _r_f16 = vsetq_lane_f16(r0[N * 7], _r_f16, 7); + _r0 = vcvt_f32_f16(vget_low_f16(_r_f16)); + _r1 = vcvt_f32_f16(vget_high_f16(_r_f16)); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float16x8_t _w45 = vld1q_f16(kptr + 16); + float16x8_t _w67 = vld1q_f16(kptr + 24); + float16x8_t _w89 = vld1q_f16(kptr + 32); + float16x8_t _wab = vld1q_f16(kptr + 40); + float16x8_t _wcd = vld1q_f16(kptr + 48); + float16x8_t _wef = vld1q_f16(kptr + 
56); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + float32x4_t _w4 = vcvt_f32_f16(vget_low_f16(_w45)); + float32x4_t _w5 = vcvt_f32_f16(vget_high_f16(_w45)); + float32x4_t _w6 = vcvt_f32_f16(vget_low_f16(_w67)); + float32x4_t _w7 = vcvt_f32_f16(vget_high_f16(_w67)); + float32x4_t _w8 = vcvt_f32_f16(vget_low_f16(_w89)); + float32x4_t _w9 = vcvt_f32_f16(vget_high_f16(_w89)); + float32x4_t _wa = vcvt_f32_f16(vget_low_f16(_wab)); + float32x4_t _wb = vcvt_f32_f16(vget_high_f16(_wab)); + float32x4_t _wc = vcvt_f32_f16(vget_low_f16(_wcd)); + float32x4_t _wd = vcvt_f32_f16(vget_high_f16(_wcd)); + float32x4_t _we = vcvt_f32_f16(vget_low_f16(_wef)); + float32x4_t _wf = vcvt_f32_f16(vget_high_f16(_wef)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w8, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w9, _r1, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _wa, _r1, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _wb, _r1, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _wc, _r1, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _wd, _r1, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _we, _r1, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _wf, _r1, 3); + + kptr += 64; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x4_t _r_f16 = float16x4_t(); + _r_f16 = 
vset_lane_f16(r0[0], _r_f16, 0); + _r_f16 = vset_lane_f16(r0[N], _r_f16, 1); + _r_f16 = vset_lane_f16(r0[N * 2], _r_f16, 2); + _r_f16 = vset_lane_f16(r0[N * 3], _r_f16, 3); + _r0 = vcvt_f32_f16(_r_f16); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float16x8_t _w45 = vld1q_f16(kptr + 16); + float16x8_t _w67 = vld1q_f16(kptr + 24); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + float32x4_t _w4 = vcvt_f32_f16(vget_low_f16(_w45)); + float32x4_t _w5 = vcvt_f32_f16(vget_high_f16(_w45)); + float32x4_t _w6 = vcvt_f32_f16(vget_low_f16(_w67)); + float32x4_t _w7 = vcvt_f32_f16(vget_high_f16(_w67)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 0); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 1); + _sum4 = vfmaq_laneq_f32(_sum4, _w4, _r0, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _w5, _r0, 2); + _sum6 = vfmaq_laneq_f32(_sum6, _w6, _r0, 3); + _sum7 = vfmaq_laneq_f32(_sum7, _w7, _r0, 3); + + kptr += 32; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = (float)(r0[0]); + val1 = (float)(r0[N]); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = vfmaq_n_f32(_sum1, _w1, val0); + _sum2 = vfmaq_n_f32(_sum2, _w2, val1); + _sum3 = vfmaq_n_f32(_sum3, _w3, val1); + + kptr += 16; + } + } + for (; q < inh; 
q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = vcvt_f32_f16(vdup_n_f16(r0[0])); + r0 += dilation_w; + } + + float16x8_t _w = vld1q_f16(kptr); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w)); + _sum0 = vfmaq_f32(_sum0, _w0, _val); + _sum1 = vfmaq_f32(_sum1, _w1, _val); + + kptr += 8; + } + } + + _sum0 = vaddq_f32(_sum0, _sum2); + _sum1 = vaddq_f32(_sum1, _sum3); + _sum4 = vaddq_f32(_sum4, _sum6); + _sum5 = vaddq_f32(_sum5, _sum7); + _sum0 = vaddq_f32(_sum0, _sum4); + _sum1 = vaddq_f32(_sum1, _sum5); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + _sum1 = activation_ps(_sum1, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1_f16(outptr, vcvt_f16_f32(_sum0)); + vst1_f16(outptr + M, vcvt_f16_f32(_sum1)); + outptr += 4; + } + if (out_elempack == 1) + { + float16x4_t _sum0_f16 = vcvt_f16_f32(_sum0); + float16x4_t _sum1_f16 = vcvt_f16_f32(_sum1); + outptr[0] = vget_lane_f16(_sum0_f16, 0); + outptr[M] = vget_lane_f16(_sum0_f16, 1); + outptr[M * 2] = vget_lane_f16(_sum0_f16, 2); + outptr[M * 3] = vget_lane_f16(_sum0_f16, 3); + outptr[M * 4] = vget_lane_f16(_sum1_f16, 0); + outptr[M * 5] = vget_lane_f16(_sum1_f16, 1); + outptr[M * 6] = vget_lane_f16(_sum1_f16, 2); + outptr[M * 7] = vget_lane_f16(_sum1_f16, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 8; + nn_outh = (outh - remain_outh_start) / 4; + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 4; + + __fp16* outptr = top_blob.row<__fp16>(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f32(bias_data_ptr + p); + } + + const __fp16* kptr = 
weight_data_tm.channel(p / 8 + (p % 8) / 4); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + _r1 = vcvt_f32_f16(vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x8_t _r_f16 = float16x8_t(); + _r_f16 = vsetq_lane_f16(r0[0], _r_f16, 0); + _r_f16 = vsetq_lane_f16(r0[N], _r_f16, 1); + _r_f16 = vsetq_lane_f16(r0[N * 2], _r_f16, 2); + _r_f16 = vsetq_lane_f16(r0[N * 3], _r_f16, 3); + _r_f16 = vsetq_lane_f16(r0[N * 4], _r_f16, 4); + _r_f16 = vsetq_lane_f16(r0[N * 5], _r_f16, 5); + _r_f16 = vsetq_lane_f16(r0[N * 6], _r_f16, 6); + _r_f16 = vsetq_lane_f16(r0[N * 7], _r_f16, 7); + _r0 = vcvt_f32_f16(vget_low_f16(_r_f16)); + _r1 = vcvt_f32_f16(vget_high_f16(_r_f16)); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float16x8_t _w45 = vld1q_f16(kptr + 16); + float16x8_t _w67 = vld1q_f16(kptr + 24); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + float32x4_t _w4 = vcvt_f32_f16(vget_low_f16(_w45)); + float32x4_t _w5 = vcvt_f32_f16(vget_high_f16(_w45)); + float32x4_t _w6 = vcvt_f32_f16(vget_low_f16(_w67)); + float32x4_t _w7 = vcvt_f32_f16(vget_high_f16(_w67)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); + _sum0 = vfmaq_laneq_f32(_sum0, _w4, _r1, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w5, _r1, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w6, _r1, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w7, _r1, 3); + + kptr += 32; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 
= bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x4_t _r_f16 = float16x4_t(); + _r_f16 = vset_lane_f16(r0[0], _r_f16, 0); + _r_f16 = vset_lane_f16(r0[N], _r_f16, 1); + _r_f16 = vset_lane_f16(r0[N * 2], _r_f16, 2); + _r_f16 = vset_lane_f16(r0[N * 3], _r_f16, 3); + _r0 = vcvt_f32_f16(_r_f16); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + _sum0 = vfmaq_laneq_f32(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _w3, _r0, 3); + + kptr += 16; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = (float)(r0[0]); + val1 = (float)(r0[N]); + r0 += dilation_w; + } + + float16x8_t _w = vld1q_f16(kptr); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w)); + _sum0 = vfmaq_n_f32(_sum0, _w0, val0); + _sum1 = vfmaq_n_f32(_sum1, _w1, val1); + + kptr += 8; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _val; + // if (elempack == 1) + { + _val = vcvt_f32_f16(vdup_n_f16(r0[0])); + r0 += dilation_w; + } + + float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr)); + _sum0 = vfmaq_f32(_sum0, _val, _w); + + kptr += 4; + } + } + + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _sum0 = vaddq_f32(_sum0, _sum2); + + 
_sum0 = activation_ps(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1_f16(outptr, vcvt_f16_f32(_sum0)); + outptr += 4; + } + if (out_elempack == 1) + { + float16x4_t _sum0_f16 = vcvt_f16_f32(_sum0); + outptr[0] = vget_lane_f16(_sum0_f16, 0); + outptr[M] = vget_lane_f16(_sum0_f16, 1); + outptr[M * 2] = vget_lane_f16(_sum0_f16, 2); + outptr[M * 3] = vget_lane_f16(_sum0_f16, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 4; + nn_outh = (outh - remain_outh_start) / 2; + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 2; + + __fp16* outptr0 = top_blob.row<__fp16>(p); + __fp16* outptr1 = top_blob.row<__fp16>(p + 1); + + for (int j = 0; j < outw; j++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; + } + + const __fp16* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2); + + int q = 0; + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + _r1 = vcvt_f32_f16(vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x8_t _r01_f16 = float16x8_t(); + _r01_f16 = vsetq_lane_f16(r0[0], _r01_f16, 0); + _r01_f16 = vsetq_lane_f16(r0[N], _r01_f16, 1); + _r01_f16 = vsetq_lane_f16(r0[N * 2], _r01_f16, 2); + _r01_f16 = vsetq_lane_f16(r0[N * 3], _r01_f16, 3); + _r01_f16 = vsetq_lane_f16(r0[N * 4], _r01_f16, 4); + _r01_f16 = vsetq_lane_f16(r0[N * 5], _r01_f16, 5); + _r01_f16 = vsetq_lane_f16(r0[N * 6], _r01_f16, 6); + _r01_f16 = vsetq_lane_f16(r0[N * 7], _r01_f16, 7); + _r0 = vcvt_f32_f16(vget_low_f16(_r01_f16)); + _r1 = 
vcvt_f32_f16(vget_high_f16(_r01_f16)); + r0 += dilation_w; + } + + float16x8_t _w01 = vld1q_f16(kptr); + float16x8_t _w23 = vld1q_f16(kptr + 8); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w01)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w01)); + float32x4_t _w2 = vcvt_f32_f16(vget_low_f16(_w23)); + float32x4_t _w3 = vcvt_f32_f16(vget_high_f16(_w23)); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + _sum2 = vfmaq_f32(_sum2, _r0, _w2); + _sum3 = vfmaq_f32(_sum3, _r1, _w3); + + kptr += 16; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + sum0 += vaddvq_f32(_sum0); + sum1 += vaddvq_f32(_sum2); + _sum0 = vdupq_n_f32(0.f); + _sum1 = vdupq_n_f32(0.f); + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x4_t _r0_f16 = float16x4_t(); + _r0_f16 = vset_lane_f16(r0[0], _r0_f16, 0); + _r0_f16 = vset_lane_f16(r0[N], _r0_f16, 1); + _r0_f16 = vset_lane_f16(r0[N * 2], _r0_f16, 2); + _r0_f16 = vset_lane_f16(r0[N * 3], _r0_f16, 3); + _r0 = vcvt_f32_f16(_r0_f16); + r0 += dilation_w; + } + + float16x8_t _w = vld1q_f16(kptr); + float32x4_t _w0 = vcvt_f32_f16(vget_low_f16(_w)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w)); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r0, _w1); + + kptr += 8; + } + } + sum0 += vaddvq_f32(_sum0); + sum1 += vaddvq_f32(_sum1); + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = (float)(r0[0]); + val1 = (float)(r0[N]); + r0 += dilation_w; + } + + sum0 += val0 * (float)(kptr[0]); + sum1 += val0 * (float)(kptr[1]); + sum0 += val1 * (float)(kptr[2]); + sum1 += val1 * 
(float)(kptr[3]); + + kptr += 4; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = (float)(r0[0]); + r0 += dilation_w; + } + + sum0 += val * (float)(kptr[0]); + sum1 += val * (float)(kptr[1]); + + kptr += 2; + } + } + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = (__fp16)(sum0); + outptr1[0] = (__fp16)(sum1); + outptr0 += 1; + outptr1 += 1; + } + } + remain_outh_start += nn_outh * 2; + for (int p = remain_outh_start; p < outh; p++) + { + __fp16* outptr = top_blob.row<__fp16>(p); + + for (int j = 0; j < outw; j++) + { + float sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + + const __fp16* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); + + int q = 0; + float32x4_t _sum0 = vdupq_n_f32(0.f); + float32x4_t _sum1 = vdupq_n_f32(0.f); + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + float32x4_t _r1; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + _r1 = vcvt_f32_f16(vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x8_t _r01_f16 = float16x8_t(); + _r01_f16 = vsetq_lane_f16(r0[0], _r01_f16, 0); + _r01_f16 = vsetq_lane_f16(r0[N], _r01_f16, 1); + _r01_f16 = vsetq_lane_f16(r0[N * 2], _r01_f16, 2); + _r01_f16 = vsetq_lane_f16(r0[N * 3], _r01_f16, 3); + _r01_f16 = vsetq_lane_f16(r0[N * 4], _r01_f16, 4); + _r01_f16 = vsetq_lane_f16(r0[N * 5], _r01_f16, 5); + _r01_f16 = vsetq_lane_f16(r0[N * 6], _r01_f16, 6); + _r01_f16 = vsetq_lane_f16(r0[N * 7], _r01_f16, 7); + _r0 = vcvt_f32_f16(vget_low_f16(_r01_f16)); + _r1 = vcvt_f32_f16(vget_high_f16(_r01_f16)); + r0 += dilation_w; + } + + float16x8_t _w = vld1q_f16(kptr); + float32x4_t _w0 = 
vcvt_f32_f16(vget_low_f16(_w)); + float32x4_t _w1 = vcvt_f32_f16(vget_high_f16(_w)); + _sum0 = vfmaq_f32(_sum0, _r0, _w0); + _sum1 = vfmaq_f32(_sum1, _r1, _w1); + + kptr += 8; + } + } + _sum0 = vaddq_f32(_sum0, _sum1); + sum += vaddvq_f32(_sum0); + float32x4_t _sum = vdupq_n_f32(0.f); + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float32x4_t _r0; + if (elempack == 4) + { + _r0 = vcvt_f32_f16(vld1_f16(r0)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + float16x4_t _r0_f16 = float16x4_t(); + _r0_f16 = vset_lane_f16(r0[0], _r0_f16, 0); + _r0_f16 = vset_lane_f16(r0[N], _r0_f16, 1); + _r0_f16 = vset_lane_f16(r0[N * 2], _r0_f16, 2); + _r0_f16 = vset_lane_f16(r0[N * 3], _r0_f16, 3); + _r0 = vcvt_f32_f16(_r0_f16); + r0 += dilation_w; + } + + float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr)); + _sum = vfmaq_f32(_sum, _r0, _w); + + kptr += 4; + } + } + sum += vaddvq_f32(_sum); + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val0; + float val1; + // if (elempack == 1) + { + val0 = (float)(r0[0]); + val1 = (float)(r0[N]); + r0 += dilation_w; + } + + sum += val0 * (float)(kptr[0]); + sum += val1 * (float)(kptr[1]); + + kptr += 2; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float val; + // if (elempack == 1) + { + val = (float)(r0[0]); + r0 += dilation_w; + } + + sum += val * (float)(kptr[0]); + + kptr += 1; + } + } + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = (__fp16)(sum); + outptr += 1; + } + } +} + +static void convolution1d_packed_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt) +{ + 
const int elempack = bottom_blob.elempack; + const int inh = bottom_blob.h * elempack; + + const int N = bottom_blob.w * elempack; + + const int outw = top_blob.w; + const int out_elempack = top_blob.elempack; + const int outh = top_blob.h * out_elempack; + + const int M = top_blob.w * out_elempack; + + const __fp16* bias_data_ptr = bias_data; + + int nn_outh = 0; + int remain_outh_start = 0; + nn_outh = (outh - remain_outh_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 8; + + __fp16* outptr = top_blob.row<__fp16>(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float16x8_t _sum0 = vdupq_n_f16(0.f); + float16x8_t _sum1 = vdupq_n_f16(0.f); + float16x8_t _sum2 = vdupq_n_f16(0.f); + float16x8_t _sum3 = vdupq_n_f16(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1q_f16(bias_data_ptr + p); + } + + const __fp16* kptr = weight_data_tm.channel(p / 8); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x8_t _r0; + if (elempack == 8) + { + _r0 = vld1q_f16(r0); + r0 += dilation_w * 8; + } + if (elempack == 4) + { + _r0 = vcombine_f16(vld1_f16(r0), vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = vsetq_lane_f16(r0[0], _r0, 0); + _r0 = vsetq_lane_f16(r0[N], _r0, 1); + _r0 = vsetq_lane_f16(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f16(r0[N * 3], _r0, 3); + _r0 = vsetq_lane_f16(r0[N * 4], _r0, 4); + _r0 = vsetq_lane_f16(r0[N * 5], _r0, 5); + _r0 = vsetq_lane_f16(r0[N * 6], _r0, 6); + _r0 = vsetq_lane_f16(r0[N * 7], _r0, 7); + r0 += dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + float16x8_t _w1 = vld1q_f16(kptr + 8); + float16x8_t _w2 = vld1q_f16(kptr + 8 * 2); + float16x8_t _w3 = vld1q_f16(kptr + 8 * 3); + float16x8_t _w4 = vld1q_f16(kptr + 8 * 4); + float16x8_t _w5 = vld1q_f16(kptr + 8 * 5); + 
float16x8_t _w6 = vld1q_f16(kptr + 8 * 6); + float16x8_t _w7 = vld1q_f16(kptr + 8 * 7); + _sum0 = vfmaq_laneq_f16(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_laneq_f16(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_laneq_f16(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_laneq_f16(_sum3, _w3, _r0, 3); + _sum0 = vfmaq_laneq_f16(_sum0, _w4, _r0, 4); + _sum1 = vfmaq_laneq_f16(_sum1, _w5, _r0, 5); + _sum2 = vfmaq_laneq_f16(_sum2, _w6, _r0, 6); + _sum3 = vfmaq_laneq_f16(_sum3, _w7, _r0, 7); + + kptr += 64; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x4_t _r0; + if (elempack == 4) + { + _r0 = vld1_f16(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x4_t(); + _r0 = vset_lane_f16(r0[0], _r0, 0); + _r0 = vset_lane_f16(r0[N], _r0, 1); + _r0 = vset_lane_f16(r0[N * 2], _r0, 2); + _r0 = vset_lane_f16(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + float16x8_t _w1 = vld1q_f16(kptr + 8); + float16x8_t _w2 = vld1q_f16(kptr + 8 * 2); + float16x8_t _w3 = vld1q_f16(kptr + 8 * 3); + _sum0 = vfmaq_lane_f16(_sum0, _w0, _r0, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _r0, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _r0, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _r0, 3); + + kptr += 32; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + __fp16 val0; + __fp16 val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + float16x8_t _w1 = vld1q_f16(kptr + 8); + _sum0 = vfmaq_n_f16(_sum0, _w0, val0); + _sum1 = vfmaq_n_f16(_sum1, _w1, val1); + + kptr += 16; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float16x8_t _val; + // if (elempack == 1) + { + _val = vdupq_n_f16(r0[0]); + r0 += 
dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + _sum0 = vfmaq_f16(_sum0, _w0, _val); + + kptr += 8; + } + } + + _sum0 = vaddq_f16(_sum0, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _sum0 = vaddq_f16(_sum0, _sum2); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + + if (out_elempack == 8) + { + vst1q_f16(outptr, _sum0); + outptr += 8; + } + if (out_elempack == 4) + { + vst1_f16(outptr, vget_low_f16(_sum0)); + vst1_f16(outptr + M, vget_high_f16(_sum0)); + outptr += 4; + } + if (out_elempack == 1) + { + outptr[0] = vgetq_lane_f16(_sum0, 0); + outptr[M] = vgetq_lane_f16(_sum0, 1); + outptr[M * 2] = vgetq_lane_f16(_sum0, 2); + outptr[M * 3] = vgetq_lane_f16(_sum0, 3); + outptr[M * 4] = vgetq_lane_f16(_sum0, 4); + outptr[M * 5] = vgetq_lane_f16(_sum0, 5); + outptr[M * 6] = vgetq_lane_f16(_sum0, 6); + outptr[M * 7] = vgetq_lane_f16(_sum0, 7); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 8; + nn_outh = (outh - remain_outh_start) / 4; + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 4; + + __fp16* outptr = top_blob.row<__fp16>(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + float16x4_t _sum0 = vdup_n_f16(0.f); + float16x4_t _sum1 = vdup_n_f16(0.f); + float16x4_t _sum2 = vdup_n_f16(0.f); + float16x4_t _sum3 = vdup_n_f16(0.f); + + if (bias_data_ptr) + { + _sum0 = vld1_f16(bias_data_ptr + p); + } + + const __fp16* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4); + + int q = 0; + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x4_t _r0; + float16x4_t _r1; + if (elempack == 8) + { + float16x8_t _r01 = vld1q_f16(r0); + _r0 = vget_low_f16(_r01); + _r1 = vget_high_f16(_r01); + r0 += dilation_w * 8; + } + if (elempack == 4) + { + _r0 = vld1_f16(r0); + _r1 = vld1_f16(r0 + N); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x4_t(); + _r1 = float16x4_t(); + 
_r0 = vset_lane_f16(r0[0], _r0, 0); + _r0 = vset_lane_f16(r0[N], _r0, 1); + _r0 = vset_lane_f16(r0[N * 2], _r0, 2); + _r0 = vset_lane_f16(r0[N * 3], _r0, 3); + _r1 = vset_lane_f16(r0[N * 4], _r1, 0); + _r1 = vset_lane_f16(r0[N * 5], _r1, 1); + _r1 = vset_lane_f16(r0[N * 6], _r1, 2); + _r1 = vset_lane_f16(r0[N * 7], _r1, 3); + r0 += dilation_w; + } + + float16x4_t _w0 = vld1_f16(kptr); + float16x4_t _w1 = vld1_f16(kptr + 4); + float16x4_t _w2 = vld1_f16(kptr + 8); + float16x4_t _w3 = vld1_f16(kptr + 12); + float16x4_t _w4 = vld1_f16(kptr + 16); + float16x4_t _w5 = vld1_f16(kptr + 20); + float16x4_t _w6 = vld1_f16(kptr + 24); + float16x4_t _w7 = vld1_f16(kptr + 28); + _sum0 = vfma_lane_f16(_sum0, _w0, _r0, 0); + _sum1 = vfma_lane_f16(_sum1, _w1, _r0, 1); + _sum2 = vfma_lane_f16(_sum2, _w2, _r0, 2); + _sum3 = vfma_lane_f16(_sum3, _w3, _r0, 3); + _sum0 = vfma_lane_f16(_sum0, _w4, _r1, 0); + _sum1 = vfma_lane_f16(_sum1, _w5, _r1, 1); + _sum2 = vfma_lane_f16(_sum2, _w6, _r1, 2); + _sum3 = vfma_lane_f16(_sum3, _w7, _r1, 3); + + kptr += 32; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x4_t _r0; + if (elempack == 4) + { + _r0 = vld1_f16(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x4_t(); + _r0 = vset_lane_f16(r0[0], _r0, 0); + _r0 = vset_lane_f16(r0[N], _r0, 1); + _r0 = vset_lane_f16(r0[N * 2], _r0, 2); + _r0 = vset_lane_f16(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float16x4_t _w0 = vld1_f16(kptr); + float16x4_t _w1 = vld1_f16(kptr + 4); + float16x4_t _w2 = vld1_f16(kptr + 8); + float16x4_t _w3 = vld1_f16(kptr + 12); + _sum0 = vfma_lane_f16(_sum0, _w0, _r0, 0); + _sum1 = vfma_lane_f16(_sum1, _w1, _r0, 1); + _sum2 = vfma_lane_f16(_sum2, _w2, _r0, 2); + _sum3 = vfma_lane_f16(_sum3, _w3, _r0, 3); + + kptr += 16; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for 
(int k = 0; k < kernel_w; k++) + { + __fp16 val0; + __fp16 val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + float16x4_t _w0 = vld1_f16(kptr); + float16x4_t _w1 = vld1_f16(kptr + 4); + _sum0 = vfma_n_f16(_sum0, _w0, val0); + _sum1 = vfma_n_f16(_sum1, _w1, val1); + + kptr += 8; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + float16x4_t _val; + // if (elempack == 1) + { + _val = vdup_n_f16(r0[0]); + r0 += dilation_w; + } + + float16x4_t _w = vld1_f16(kptr); + _sum0 = vfma_f16(_sum0, _val, _w); + + kptr += 4; + } + } + + _sum0 = vadd_f16(_sum0, _sum1); + _sum2 = vadd_f16(_sum2, _sum3); + _sum0 = vadd_f16(_sum0, _sum2); + + _sum0 = activation_ps(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + vst1_f16(outptr, _sum0); + outptr += 4; + } + if (out_elempack == 1) + { + outptr[0] = vget_lane_f16(_sum0, 0); + outptr[M] = vget_lane_f16(_sum0, 1); + outptr[M * 2] = vget_lane_f16(_sum0, 2); + outptr[M * 3] = vget_lane_f16(_sum0, 3); + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 4; + nn_outh = (outh - remain_outh_start) / 2; + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 2; + + __fp16* outptr0 = top_blob.row<__fp16>(p); + __fp16* outptr1 = top_blob.row<__fp16>(p + 1); + + for (int j = 0; j < outw; j++) + { + __fp16 sum0 = 0.f; + __fp16 sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; + } + + const __fp16* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2); + + int q = 0; + float16x8_t _sum0 = vdupq_n_f16(0.f); + float16x8_t _sum1 = vdupq_n_f16(0.f); + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x8_t _r0; + if (elempack == 8) + { + _r0 = vld1q_f16(r0); + r0 += dilation_w * 8; + } + if 
(elempack == 4) + { + _r0 = vcombine_f16(vld1_f16(r0), vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x8_t(); + _r0 = vsetq_lane_f16(r0[0], _r0, 0); + _r0 = vsetq_lane_f16(r0[N], _r0, 1); + _r0 = vsetq_lane_f16(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f16(r0[N * 3], _r0, 3); + _r0 = vsetq_lane_f16(r0[N * 4], _r0, 4); + _r0 = vsetq_lane_f16(r0[N * 5], _r0, 5); + _r0 = vsetq_lane_f16(r0[N * 6], _r0, 6); + _r0 = vsetq_lane_f16(r0[N * 7], _r0, 7); + r0 += dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + float16x8_t _w1 = vld1q_f16(kptr + 8); + _sum0 = vfmaq_f16(_sum0, _r0, _w0); + _sum1 = vfmaq_f16(_sum1, _r0, _w1); + + kptr += 16; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x4_t _r0; + if (elempack == 4) + { + _r0 = vld1_f16(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x4_t(); + _r0 = vset_lane_f16(r0[0], _r0, 0); + _r0 = vset_lane_f16(r0[N], _r0, 1); + _r0 = vset_lane_f16(r0[N * 2], _r0, 2); + _r0 = vset_lane_f16(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float16x4_t _w0 = vld1_f16(kptr); + float16x4_t _w1 = vld1_f16(kptr + 4); + _sum0 = vcombine_f16(vfma_f16(vget_low_f16(_sum0), _r0, _w0), vget_high_f16(_sum0)); + _sum1 = vcombine_f16(vfma_f16(vget_low_f16(_sum1), _r0, _w1), vget_high_f16(_sum1)); + + kptr += 8; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + __fp16 val0; + __fp16 val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + sum0 += val0 * kptr[0]; + sum1 += val0 * kptr[1]; + sum0 += val1 * kptr[2]; + sum1 += val1 * kptr[3]; + + kptr += 4; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + __fp16 val; + // if (elempack == 1) + { + val = 
r0[0]; + r0 += dilation_w; + } + + sum0 += val * kptr[0]; + sum1 += val * kptr[1]; + + kptr += 2; + } + } + + float16x4_t _ss0 = vadd_f16(vget_low_f16(_sum0), vget_high_f16(_sum0)); + float16x4_t _ss1 = vadd_f16(vget_low_f16(_sum1), vget_high_f16(_sum1)); + float16x4_t _ss = vpadd_f16(_ss0, _ss1); + _ss = vpadd_f16(_ss, _ss); + sum0 += vget_lane_f16(_ss, 0); + sum1 += vget_lane_f16(_ss, 1); + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = sum0; + outptr1[0] = sum1; + outptr0 += 1; + outptr1 += 1; + } + } + remain_outh_start += nn_outh * 2; + for (int p = remain_outh_start; p < outh; p++) + { + __fp16* outptr = top_blob.row<__fp16>(p); + + for (int j = 0; j < outw; j++) + { + __fp16 sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + + const __fp16* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); + + int q = 0; + float16x8_t _sum = vdupq_n_f16(0.f); + for (; q + 7 < inh; q += 8) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 0; k < kernel_w; k++) + { + float16x8_t _r0; + if (elempack == 8) + { + _r0 = vld1q_f16(r0); + r0 += dilation_w * 8; + } + if (elempack == 4) + { + _r0 = vcombine_f16(vld1_f16(r0), vld1_f16(r0 + N)); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x8_t(); + _r0 = vsetq_lane_f16(r0[0], _r0, 0); + _r0 = vsetq_lane_f16(r0[N], _r0, 1); + _r0 = vsetq_lane_f16(r0[N * 2], _r0, 2); + _r0 = vsetq_lane_f16(r0[N * 3], _r0, 3); + _r0 = vsetq_lane_f16(r0[N * 4], _r0, 4); + _r0 = vsetq_lane_f16(r0[N * 5], _r0, 5); + _r0 = vsetq_lane_f16(r0[N * 6], _r0, 6); + _r0 = vsetq_lane_f16(r0[N * 7], _r0, 7); + r0 += dilation_w; + } + + float16x8_t _w0 = vld1q_f16(kptr); + _sum = vfmaq_f16(_sum, _r0, _w0); + + kptr += 8; + } + } + for (; q + 3 < inh; q += 4) + { + const __fp16* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + for (int k = 
0; k < kernel_w; k++) + { + float16x4_t _r0; + if (elempack == 4) + { + _r0 = vld1_f16(r0); + r0 += dilation_w * 4; + } + if (elempack == 1) + { + _r0 = float16x4_t(); + _r0 = vset_lane_f16(r0[0], _r0, 0); + _r0 = vset_lane_f16(r0[N], _r0, 1); + _r0 = vset_lane_f16(r0[N * 2], _r0, 2); + _r0 = vset_lane_f16(r0[N * 3], _r0, 3); + r0 += dilation_w; + } + + float16x4_t _w = vld1_f16(kptr); + _sum = vcombine_f16(vfma_f16(vget_low_f16(_sum), _r0, _w), vget_high_f16(_sum)); + + kptr += 4; + } + } + for (; q + 1 < inh; q += 2) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + __fp16 val0; + __fp16 val1; + // if (elempack == 1) + { + val0 = r0[0]; + val1 = r0[N]; + r0 += dilation_w; + } + + sum += val0 * kptr[0]; + sum += val1 * kptr[1]; + + kptr += 2; + } + } + for (; q < inh; q++) + { + const __fp16* r0 = bottom_blob.row(q) + j * stride_w; + + for (int k = 0; k < kernel_w; k++) + { + __fp16 val; + // if (elempack == 1) + { + val = r0[0]; + r0 += dilation_w; + } + + sum += val * kptr[0]; + + kptr += 1; + } + } + + float16x4_t _ss = vadd_f16(vget_low_f16(_sum), vget_high_f16(_sum)); + _ss = vpadd_f16(_ss, _ss); + _ss = vpadd_f16(_ss, _ss); + sum += vget_lane_f16(_ss, 0); + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = sum; + outptr += 1; + } + } +} diff --git a/src/layer/arm/convolution_packed_fp16s.h b/src/layer/arm/convolution_packed_fp16s.h index 5463cb22f..0956295a7 100644 --- a/src/layer/arm/convolution_packed_fp16s.h +++ b/src/layer/arm/convolution_packed_fp16s.h @@ -32,8 +32,7 @@ static void convolution_transform_kernel_packed_fp16s(const Mat& kernel, Mat& ke else kernel_tm.create(8 * maxk, inch, outch / 8 + (outch % 8) / 4 + (outch % 4) / 2 + outch % 2, (size_t)2u); } - else - if (outch >= 4) + else if (outch >= 4) { if (inch >= 8) kernel_tm.create(4 * 8 * maxk, inch / 8 + (inch % 8) / 4 + (inch % 4) / 2 + inch % 2, outch / 4 + (outch % 4) / 2 + outch % 2, (size_t)2u); @@ -44,8 
+43,7 @@ static void convolution_transform_kernel_packed_fp16s(const Mat& kernel, Mat& ke else kernel_tm.create(4 * maxk, inch, outch / 4 + (outch % 4) / 2 + outch % 2, (size_t)2u); } - else - if (outch >= 2) + else if (outch >= 2) { if (inch >= 8) kernel_tm.create(2 * 8 * maxk, inch / 8 + (inch % 8) / 4 + (inch % 4) / 2 + inch % 2, outch / 2 + outch % 2, (size_t)2u); diff --git a/src/layer/x86/convolution1d_packed.h b/src/layer/x86/convolution1d_packed.h new file mode 100644 index 000000000..b17a90bee --- /dev/null +++ b/src/layer/x86/convolution1d_packed.h @@ -0,0 +1,2702 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void convolution1d_transform_kernel_packed(const Mat& kernel, Mat& kernel_tm, int inh, int outh, int kernel_w) +{ + // src = kw-inh-outh + // dst = pb-pa-kw-inh/pa-outh/pb + + // clang-format off + // *INDENT-OFF* +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (outh >= 16) + { + if (inh >= 16) + kernel_tm.create(16 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 8) + kernel_tm.create(16 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 4) + kernel_tm.create(16 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(16 * 2 * kernel_w, inh / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else + kernel_tm.create(16 * kernel_w, inh, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + } + else +#endif // __AVX512F__ + if (outh >= 8) + { +#if __AVX512F__ + if (inh >= 16) + kernel_tm.create(8 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else +#endif // __AVX512F__ + if (inh >= 8) + kernel_tm.create(8 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 4) + kernel_tm.create(8 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(8 * 2 * kernel_w, inh / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + else + kernel_tm.create(8 * kernel_w, inh, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2); + } + else +#endif // __AVX__ + if (outh >= 4) + { +#if 
__AVX__ +#if __AVX512F__ + if (inh >= 16) + kernel_tm.create(4 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else +#endif // __AVX512F__ + if (inh >= 8) + kernel_tm.create(4 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else +#endif // __AVX__ + if (inh >= 4) + kernel_tm.create(4 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(4 * 2 * kernel_w, inh / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2); + else + kernel_tm.create(4 * kernel_w, inh, outh / 4 + (outh % 4) / 2 + outh % 2); + } + else +#endif // __SSE2__ + if (outh >= 2) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (inh >= 16) + kernel_tm.create(2 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2); + else +#endif // __AVX512F__ + if (inh >= 8) + kernel_tm.create(2 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2); + else +#endif // __AVX__ + if (inh >= 4) + kernel_tm.create(2 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2); + else if (inh >= 2) + kernel_tm.create(2 * 2 * kernel_w, inh / 2 + inh % 2, outh / 2 + outh % 2); + else +#endif // __SSE2__ + kernel_tm.create(2 * kernel_w, inh, outh / 2 + outh % 2); + } + else + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (inh >= 16) + kernel_tm.create(16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh); + else +#endif // __AVX512F__ + if (inh >= 8) + kernel_tm.create(8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh); + else +#endif // __AVX__ + if (inh >= 4) + kernel_tm.create(4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh); + else if (inh >= 2) + kernel_tm.create(2 * kernel_w, inh / 2 + inh % 2, outh); + else +#endif // __SSE2__ + 
kernel_tm.create(kernel_w, inh, outh); + } + // *INDENT-ON* + // clang-format on + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; q + 15 < outh; q += 16) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w; + const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w; + const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w; + const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w; + const float* kptr8 = (const float*)kernel + (q + 8) * inh * kernel_w; + const float* kptr9 = (const float*)kernel + (q + 9) * inh * kernel_w; + const float* kptra = (const float*)kernel + (q + 10) * inh * kernel_w; + const float* kptrb = (const float*)kernel + (q + 11) * inh * kernel_w; + const float* kptrc = (const float*)kernel + (q + 12) * inh * kernel_w; + const float* kptrd = (const float*)kernel + (q + 13) * inh * kernel_w; + const float* kptre = (const float*)kernel + (q + 14) * inh * kernel_w; + const float* kptrf = (const float*)kernel + (q + 15) * inh * kernel_w; + + float* g00 = kernel_tm.channel(q / 16); + + int p = 0; +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < inh; p += 16) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + const float* k8 = kptr8 + p * kernel_w; + const float* k9 = kptr9 + p * kernel_w; + const float* ka = kptra + p * kernel_w; + const float* kb = kptrb + p * kernel_w; 
+ const float* kc = kptrc + p * kernel_w; + const float* kd = kptrd + p * kernel_w; + const float* ke = kptre + p * kernel_w; + const float* kf = kptrf + p * kernel_w; + + for (int i = 0; i < 16; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00[8] = k8[k]; + g00[9] = k9[k]; + g00[10] = ka[k]; + g00[11] = kb[k]; + g00[12] = kc[k]; + g00[13] = kd[k]; + g00[14] = ke[k]; + g00[15] = kf[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + k8 += kernel_w; + k9 += kernel_w; + ka += kernel_w; + kb += kernel_w; + kc += kernel_w; + kd += kernel_w; + ke += kernel_w; + kf += kernel_w; + g00 += 16; + } + } + } +#endif // __AVX512F__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + const float* k8 = kptr8 + p * kernel_w; + const float* k9 = kptr9 + p * kernel_w; + const float* ka = kptra + p * kernel_w; + const float* kb = kptrb + p * kernel_w; + const float* kc = kptrc + p * kernel_w; + const float* kd = kptrd + p * kernel_w; + const float* ke = kptre + p * kernel_w; + const float* kf = kptrf + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00[8] = k8[k]; + g00[9] = k9[k]; + g00[10] = ka[k]; + g00[11] = kb[k]; + g00[12] = kc[k]; + g00[13] = kd[k]; + g00[14] = ke[k]; + g00[15] = kf[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += 
kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + k8 += kernel_w; + k9 += kernel_w; + ka += kernel_w; + kb += kernel_w; + kc += kernel_w; + kd += kernel_w; + ke += kernel_w; + kf += kernel_w; + g00 += 16; + } + } + } +#endif // __AVX__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + const float* k8 = kptr8 + p * kernel_w; + const float* k9 = kptr9 + p * kernel_w; + const float* ka = kptra + p * kernel_w; + const float* kb = kptrb + p * kernel_w; + const float* kc = kptrc + p * kernel_w; + const float* kd = kptrd + p * kernel_w; + const float* ke = kptre + p * kernel_w; + const float* kf = kptrf + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00[8] = k8[k]; + g00[9] = k9[k]; + g00[10] = ka[k]; + g00[11] = kb[k]; + g00[12] = kc[k]; + g00[13] = kd[k]; + g00[14] = ke[k]; + g00[15] = kf[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + k8 += kernel_w; + k9 += kernel_w; + ka += kernel_w; + kb += kernel_w; + kc += kernel_w; + kd += kernel_w; + ke += kernel_w; + kf += kernel_w; + g00 += 16; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p 
* kernel_w; + const float* k7 = kptr7 + p * kernel_w; + const float* k8 = kptr8 + p * kernel_w; + const float* k9 = kptr9 + p * kernel_w; + const float* ka = kptra + p * kernel_w; + const float* kb = kptrb + p * kernel_w; + const float* kc = kptrc + p * kernel_w; + const float* kd = kptrd + p * kernel_w; + const float* ke = kptre + p * kernel_w; + const float* kf = kptrf + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00[8] = k8[k]; + g00[9] = k9[k]; + g00[10] = ka[k]; + g00[11] = kb[k]; + g00[12] = kc[k]; + g00[13] = kd[k]; + g00[14] = ke[k]; + g00[15] = kf[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + k8 += kernel_w; + k9 += kernel_w; + ka += kernel_w; + kb += kernel_w; + kc += kernel_w; + kd += kernel_w; + ke += kernel_w; + kf += kernel_w; + g00 += 16; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + const float* k8 = kptr8 + p * kernel_w; + const float* k9 = kptr9 + p * kernel_w; + const float* ka = kptra + p * kernel_w; + const float* kb = kptrb + p * kernel_w; + const float* kc = kptrc + p * kernel_w; + const float* kd = kptrd + p * kernel_w; + const float* ke = kptre + p * kernel_w; + const float* kf = kptrf + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00[8] = k8[k]; + g00[9] = k9[k]; + g00[10] = ka[k]; + g00[11] = kb[k]; + 
g00[12] = kc[k]; + g00[13] = kd[k]; + g00[14] = ke[k]; + g00[15] = kf[k]; + g00 += 16; + } + } + } +#endif // __AVX512F__ + for (; q + 7 < outh; q += 8) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w; + const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w; + const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w; + const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w; + +#if __AVX512F__ + float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8); +#else + float* g00 = kernel_tm.channel(q / 8); +#endif + + int p = 0; +#if __AVX512F__ + for (; p + 15 < inh; p += 16) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 16; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } +#endif // __AVX512F__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * 
kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + k4 += kernel_w; + k5 += kernel_w; + k6 += kernel_w; + k7 += kernel_w; + g00 += 8; + } + } + } + for (; p < inh; p++) + { + const 
float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + const float* k4 = kptr4 + p * kernel_w; + const float* k5 = kptr5 + p * kernel_w; + const float* k6 = kptr6 + p * kernel_w; + const float* k7 = kptr7 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00[4] = k4[k]; + g00[5] = k5[k]; + g00[6] = k6[k]; + g00[7] = k7[k]; + g00 += 8; + } + } + } +#endif // __AVX__ + for (; q + 3 < outh; q += 4) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w; + const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w; + +#if __AVX512F__ + float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4); +#elif __AVX__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4); +#else + float* g00 = kernel_tm.channel(q / 4); +#endif + + int p = 0; +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < inh; p += 16) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 16; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } +#endif // __AVX512F__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + 
k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } +#endif // __AVX__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + k0 += kernel_w; + k1 += kernel_w; + k2 += kernel_w; + k3 += kernel_w; + g00 += 4; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + const float* k2 = kptr2 + p * kernel_w; + const float* k3 = kptr3 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00[2] = k2[k]; + g00[3] = k3[k]; + g00 += 4; + } + } + } +#endif // __SSE2__ + for (; q + 1 < outh; q += 2) + { + const float* kptr0 = (const float*)kernel + q * inh * kernel_w; + const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w; + +#if __AVX512F__ + float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __AVX__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __SSE2__ + float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2); +#else + float* g00 = kernel_tm.channel(q / 2); +#endif + + int p = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < inh; p += 16) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + 
const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = k0[0]; + g00[1] = k0[kernel_w]; + g00[2] = k0[kernel_w * 2]; + g00[3] = k0[kernel_w * 3]; + g00[4] = k0[kernel_w * 4]; + g00[5] = k0[kernel_w * 5]; + g00[6] = k0[kernel_w * 6]; + g00[7] = k0[kernel_w * 7]; + g00[8] = k0[kernel_w * 8]; + g00[9] = k0[kernel_w * 9]; + g00[10] = k0[kernel_w * 10]; + g00[11] = k0[kernel_w * 11]; + g00[12] = k0[kernel_w * 12]; + g00[13] = k0[kernel_w * 13]; + g00[14] = k0[kernel_w * 14]; + g00[15] = k0[kernel_w * 15]; + g00[16] = k1[0]; + g00[17] = k1[kernel_w]; + g00[18] = k1[kernel_w * 2]; + g00[19] = k1[kernel_w * 3]; + g00[20] = k1[kernel_w * 4]; + g00[21] = k1[kernel_w * 5]; + g00[22] = k1[kernel_w * 6]; + g00[23] = k1[kernel_w * 7]; + g00[24] = k1[kernel_w * 8]; + g00[25] = k1[kernel_w * 9]; + g00[26] = k1[kernel_w * 10]; + g00[27] = k1[kernel_w * 11]; + g00[28] = k1[kernel_w * 12]; + g00[29] = k1[kernel_w * 13]; + g00[30] = k1[kernel_w * 14]; + g00[31] = k1[kernel_w * 15]; + g00 += 32; + } + } +#endif // __AVX512F__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = k0[0]; + g00[1] = k0[kernel_w]; + g00[2] = k0[kernel_w * 2]; + g00[3] = k0[kernel_w * 3]; + g00[4] = k0[kernel_w * 4]; + g00[5] = k0[kernel_w * 5]; + g00[6] = k0[kernel_w * 6]; + g00[7] = k0[kernel_w * 7]; + g00[8] = k1[0]; + g00[9] = k1[kernel_w]; + g00[10] = k1[kernel_w * 2]; + g00[11] = k1[kernel_w * 3]; + g00[12] = k1[kernel_w * 4]; + g00[13] = k1[kernel_w * 5]; + g00[14] = k1[kernel_w * 6]; + g00[15] = k1[kernel_w * 7]; + g00 += 16; + } + } +#endif // __AVX__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w + k; + const float* k1 = kptr1 + p * kernel_w + k; + + g00[0] = k0[0]; + g00[1] = k0[kernel_w]; + g00[2] = k0[kernel_w * 2]; + g00[3] = k0[kernel_w * 3]; + g00[4] = k1[0]; + g00[5] = k1[kernel_w]; + g00[6] = 
k1[kernel_w * 2]; + g00[7] = k1[kernel_w * 3]; + g00 += 8; + } + } +#endif // __SSE2__ + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + k0 += kernel_w; + k1 += kernel_w; + g00 += 2; + } + } + } + for (; p < inh; p++) + { + const float* k0 = kptr0 + p * kernel_w; + const float* k1 = kptr1 + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00[1] = k1[k]; + g00 += 2; + } + } + } + for (; q < outh; q++) + { + const float* kptr = (const float*)kernel + q * inh * kernel_w; + +#if __AVX512F__ + float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __AVX__ + float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __SSE2__ + float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2 + q % 2); +#else + float* g00 = kernel_tm.channel(q / 2 + q % 2); +#endif + + int p = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < inh; p += 16) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 16; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __AVX512F__ + for (; p + 7 < inh; p += 8) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 8; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __AVX__ + for (; p + 3 < inh; p += 4) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 4; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } + } +#endif // __SSE2__ + for (; p + 1 < inh; p += 2) + { + for (int k = 0; k < kernel_w; k++) + { + const float* k0 = kptr + p * kernel_w; + + for (int i = 0; i < 2; i++) + { + g00[0] = k0[k]; + k0 += kernel_w; + g00 += 1; + } + } 
+ } + for (; p < inh; p++) + { + const float* k0 = kptr + p * kernel_w; + + for (int k = 0; k < kernel_w; k++) + { + g00[0] = k0[k]; + g00++; + } + } + } +} + +static void convolution1d_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt) +{ + const int elempack = bottom_blob.elempack; + const int inh = bottom_blob.h * elempack; + + const int N = bottom_blob.w * elempack; + + const int outw = top_blob.w; + const int out_elempack = top_blob.elempack; + const int outh = top_blob.h * out_elempack; + + const int M = top_blob.w * out_elempack; + + const float* bias_data_ptr = bias_data; + + int nn_outh = 0; + int remain_outh_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_outh = outh / 16; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = pp * 16; + + float* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm512_loadu_ps(bias_data_ptr + p); + } + + const float* kptr = weight_data_tm.channel(p / 16); + + int q = 0; + for (; q + 15 < inh; q += 16) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 16) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + __m512 _w8 = _mm512_load_ps(kptr + 16 * 8); + __m512 _w9 = 
_mm512_load_ps(kptr + 16 * 9); + __m512 _wa = _mm512_load_ps(kptr + 16 * 10); + __m512 _wb = _mm512_load_ps(kptr + 16 * 11); + __m512 _wc = _mm512_load_ps(kptr + 16 * 12); + __m512 _wd = _mm512_load_ps(kptr + 16 * 13); + __m512 _we = _mm512_load_ps(kptr + 16 * 14); + __m512 _wf = _mm512_load_ps(kptr + 16 * 15); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3); + _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r0[8]), _sum0); + _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r0[9]), _sum1); + _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r0[10]), _sum2); + _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r0[11]), _sum3); + _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r0[12]), _sum0); + _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r0[13]), _sum1); + _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r0[14]), _sum2); + _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r0[15]), _sum3); + + r0 += dilation_w * 16; + kptr += 256; + } + } + if (elempack == 8) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + __m512 _w8 = _mm512_load_ps(kptr + 16 * 8); + __m512 _w9 = _mm512_load_ps(kptr + 16 * 9); + __m512 _wa = _mm512_load_ps(kptr + 16 * 10); + __m512 _wb = 
_mm512_load_ps(kptr + 16 * 11); + __m512 _wc = _mm512_load_ps(kptr + 16 * 12); + __m512 _wd = _mm512_load_ps(kptr + 16 * 13); + __m512 _we = _mm512_load_ps(kptr + 16 * 14); + __m512 _wf = _mm512_load_ps(kptr + 16 * 15); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3); + _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r1[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r1[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r1[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r1[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r1[4]), _sum0); + _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r1[5]), _sum1); + _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r1[6]), _sum2); + _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r1[7]), _sum3); + + r0 += dilation_w * 8; + r1 += dilation_w * 8; + kptr += 256; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + const float* r2 = r0 + N * 2; + const float* r3 = r0 + N * 3; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + __m512 _w8 = _mm512_load_ps(kptr + 16 * 8); + __m512 _w9 = _mm512_load_ps(kptr + 16 * 9); + __m512 _wa = _mm512_load_ps(kptr + 16 * 10); + __m512 _wb = _mm512_load_ps(kptr + 16 * 
11); + __m512 _wc = _mm512_load_ps(kptr + 16 * 12); + __m512 _wd = _mm512_load_ps(kptr + 16 * 13); + __m512 _we = _mm512_load_ps(kptr + 16 * 14); + __m512 _wf = _mm512_load_ps(kptr + 16 * 15); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r1[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r1[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r1[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r1[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r2[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r2[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r2[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r2[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r3[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r3[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r3[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r3[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + r2 += dilation_w * 4; + r3 += dilation_w * 4; + kptr += 256; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + __m512 _w8 = _mm512_load_ps(kptr + 16 * 8); + __m512 _w9 = _mm512_load_ps(kptr + 16 * 9); + __m512 _wa = _mm512_load_ps(kptr + 16 * 10); + __m512 _wb = _mm512_load_ps(kptr + 16 * 11); + __m512 _wc = _mm512_load_ps(kptr + 16 * 12); + __m512 _wd = 
_mm512_load_ps(kptr + 16 * 13); + __m512 _we = _mm512_load_ps(kptr + 16 * 14); + __m512 _wf = _mm512_load_ps(kptr + 16 * 15); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[N * 7]), _sum3); + _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r0[N * 8]), _sum0); + _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r0[N * 9]), _sum1); + _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r0[N * 10]), _sum2); + _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r0[N * 11]), _sum3); + _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r0[N * 12]), _sum0); + _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r0[N * 13]), _sum1); + _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r0[N * 14]), _sum2); + _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r0[N * 15]), _sum3); + + r0 += dilation_w; + kptr += 256; + } + } + } + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 8) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, 
_mm512_set1_ps(r0[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3); + + r0 += dilation_w * 8; + kptr += 128; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r1[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r1[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r1[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r1[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + kptr += 128; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr + 16 * 0); + __m512 _w1 = _mm512_load_ps(kptr + 16 * 1); + __m512 _w2 = _mm512_load_ps(kptr + 16 * 2); + __m512 _w3 = _mm512_load_ps(kptr + 16 * 3); + __m512 _w4 = _mm512_load_ps(kptr + 16 * 4); + __m512 _w5 = _mm512_load_ps(kptr + 16 * 5); + __m512 _w6 = _mm512_load_ps(kptr + 16 * 6); + __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2); + _sum3 = 
_mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[N * 7]), _sum3); + + r0 += dilation_w; + kptr += 128; + } + } + } + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 4) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + __m512 _w2 = _mm512_load_ps(kptr + 32); + __m512 _w3 = _mm512_load_ps(kptr + 48); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3); + + r0 += dilation_w * 4; + kptr += 64; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + __m512 _w2 = _mm512_load_ps(kptr + 32); + __m512 _w3 = _mm512_load_ps(kptr + 48); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3); + + r0 += dilation_w; + kptr += 64; + } + } + } + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + + _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1); + + r0 += dilation_w; + kptr += 32; + } + } + } + for (; q < inh; q++) + { + const float* 
r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _val = _mm512_set1_ps(r0[0]); + __m512 _w = _mm512_load_ps(kptr); + _sum0 = _mm512_fmadd_ps(_val, _w, _sum0); + + r0 += dilation_w; + kptr += 16; + } + } + } + + _sum0 = _mm512_add_ps(_sum0, _sum1); + _sum2 = _mm512_add_ps(_sum2, _sum3); + _sum0 = _mm512_add_ps(_sum0, _sum2); + + _sum0 = activation_avx512(_sum0, activation_type, activation_params); + + if (out_elempack == 16) + { + _mm512_store_ps(outptr, _sum0); + outptr += 16; + } + if (out_elempack == 8) + { + _mm256_store_ps(outptr, _mm512_extractf32x8_ps(_sum0, 0)); + _mm256_store_ps(outptr + M, _mm512_extractf32x8_ps(_sum0, 1)); + outptr += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(outptr, _mm512_extractf32x4_ps(_sum0, 0)); + _mm_store_ps(outptr + M, _mm512_extractf32x4_ps(_sum0, 1)); + _mm_store_ps(outptr + M * 2, _mm512_extractf32x4_ps(_sum0, 2)); + _mm_store_ps(outptr + M * 3, _mm512_extractf32x4_ps(_sum0, 3)); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[16]; + _mm512_storeu_ps(sum, _sum0); + + outptr[0] = sum[0]; + outptr[M] = sum[1]; + outptr[M * 2] = sum[2]; + outptr[M * 3] = sum[3]; + outptr[M * 4] = sum[4]; + outptr[M * 5] = sum[5]; + outptr[M * 6] = sum[6]; + outptr[M * 7] = sum[7]; + outptr[M * 8] = sum[8]; + outptr[M * 9] = sum[9]; + outptr[M * 10] = sum[10]; + outptr[M * 11] = sum[11]; + outptr[M * 12] = sum[12]; + outptr[M * 13] = sum[13]; + outptr[M * 14] = sum[14]; + outptr[M * 15] = sum[15]; + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 16; + nn_outh = (outh - remain_outh_start) / 8; +#else // __AVX512F__ + nn_outh = (outh - remain_outh_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __AVX512F__ + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 8; + + float* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + __m256 _sum0 = _mm256_setzero_ps(); 
+ __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm256_loadu_ps(bias_data_ptr + p); + } + +#if __AVX512F__ + const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8); +#else + const float* kptr = weight_data_tm.channel(p / 8); +#endif + + int q = 0; +#if __AVX512F__ + for (; q + 15 < inh; q += 16) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 16) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr + 8 * 0); + __m256 _w1 = _mm256_load_ps(kptr + 8 * 1); + __m256 _w2 = _mm256_load_ps(kptr + 8 * 2); + __m256 _w3 = _mm256_load_ps(kptr + 8 * 3); + __m256 _w4 = _mm256_load_ps(kptr + 8 * 4); + __m256 _w5 = _mm256_load_ps(kptr + 8 * 5); + __m256 _w6 = _mm256_load_ps(kptr + 8 * 6); + __m256 _w7 = _mm256_load_ps(kptr + 8 * 7); + __m256 _w8 = _mm256_load_ps(kptr + 8 * 8); + __m256 _w9 = _mm256_load_ps(kptr + 8 * 9); + __m256 _wa = _mm256_load_ps(kptr + 8 * 10); + __m256 _wb = _mm256_load_ps(kptr + 8 * 11); + __m256 _wc = _mm256_load_ps(kptr + 8 * 12); + __m256 _wd = _mm256_load_ps(kptr + 8 * 13); + __m256 _we = _mm256_load_ps(kptr + 8 * 14); + __m256 _wf = _mm256_load_ps(kptr + 8 * 15); + + _sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0); + _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1); + _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2); + _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3); + _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r0[8]), _sum0); + _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r0[9]), _sum1); + _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r0[10]), _sum2); + _sum3 = 
_mm256_fmadd_ps(_wb, _mm256_set1_ps(r0[11]), _sum3); + _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r0[12]), _sum0); + _sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r0[13]), _sum1); + _sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r0[14]), _sum2); + _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r0[15]), _sum3); + + r0 += dilation_w * 16; + kptr += 128; + } + } + if (elempack == 8) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr + 8 * 0); + __m256 _w1 = _mm256_load_ps(kptr + 8 * 1); + __m256 _w2 = _mm256_load_ps(kptr + 8 * 2); + __m256 _w3 = _mm256_load_ps(kptr + 8 * 3); + __m256 _w4 = _mm256_load_ps(kptr + 8 * 4); + __m256 _w5 = _mm256_load_ps(kptr + 8 * 5); + __m256 _w6 = _mm256_load_ps(kptr + 8 * 6); + __m256 _w7 = _mm256_load_ps(kptr + 8 * 7); + __m256 _w8 = _mm256_load_ps(kptr + 8 * 8); + __m256 _w9 = _mm256_load_ps(kptr + 8 * 9); + __m256 _wa = _mm256_load_ps(kptr + 8 * 10); + __m256 _wb = _mm256_load_ps(kptr + 8 * 11); + __m256 _wc = _mm256_load_ps(kptr + 8 * 12); + __m256 _wd = _mm256_load_ps(kptr + 8 * 13); + __m256 _we = _mm256_load_ps(kptr + 8 * 14); + __m256 _wf = _mm256_load_ps(kptr + 8 * 15); + + _sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0); + _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1); + _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2); + _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3); + _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r1[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r1[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r1[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r1[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r1[4]), _sum0); + _sum1 = 
_mm256_fmadd_ps(_wd, _mm256_set1_ps(r1[5]), _sum1); + _sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r1[6]), _sum2); + _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r1[7]), _sum3); + + r0 += dilation_w * 8; + r1 += dilation_w * 8; + kptr += 128; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + const float* r2 = r0 + N * 2; + const float* r3 = r0 + N * 3; + + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr + 8 * 0); + __m256 _w1 = _mm256_load_ps(kptr + 8 * 1); + __m256 _w2 = _mm256_load_ps(kptr + 8 * 2); + __m256 _w3 = _mm256_load_ps(kptr + 8 * 3); + __m256 _w4 = _mm256_load_ps(kptr + 8 * 4); + __m256 _w5 = _mm256_load_ps(kptr + 8 * 5); + __m256 _w6 = _mm256_load_ps(kptr + 8 * 6); + __m256 _w7 = _mm256_load_ps(kptr + 8 * 7); + __m256 _w8 = _mm256_load_ps(kptr + 8 * 8); + __m256 _w9 = _mm256_load_ps(kptr + 8 * 9); + __m256 _wa = _mm256_load_ps(kptr + 8 * 10); + __m256 _wb = _mm256_load_ps(kptr + 8 * 11); + __m256 _wc = _mm256_load_ps(kptr + 8 * 12); + __m256 _wd = _mm256_load_ps(kptr + 8 * 13); + __m256 _we = _mm256_load_ps(kptr + 8 * 14); + __m256 _wf = _mm256_load_ps(kptr + 8 * 15); + + _sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r1[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r1[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r1[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r1[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r2[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r2[1]), _sum1); + _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r2[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r2[3]), _sum3); + _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r3[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r3[1]), 
_sum1); + _sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r3[2]), _sum2); + _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r3[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + r2 += dilation_w * 4; + r3 += dilation_w * 4; + kptr += 128; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr + 8 * 0); + __m256 _w1 = _mm256_load_ps(kptr + 8 * 1); + __m256 _w2 = _mm256_load_ps(kptr + 8 * 2); + __m256 _w3 = _mm256_load_ps(kptr + 8 * 3); + __m256 _w4 = _mm256_load_ps(kptr + 8 * 4); + __m256 _w5 = _mm256_load_ps(kptr + 8 * 5); + __m256 _w6 = _mm256_load_ps(kptr + 8 * 6); + __m256 _w7 = _mm256_load_ps(kptr + 8 * 7); + __m256 _w8 = _mm256_load_ps(kptr + 8 * 8); + __m256 _w9 = _mm256_load_ps(kptr + 8 * 9); + __m256 _wa = _mm256_load_ps(kptr + 8 * 10); + __m256 _wb = _mm256_load_ps(kptr + 8 * 11); + __m256 _wc = _mm256_load_ps(kptr + 8 * 12); + __m256 _wd = _mm256_load_ps(kptr + 8 * 13); + __m256 _we = _mm256_load_ps(kptr + 8 * 14); + __m256 _wf = _mm256_load_ps(kptr + 8 * 15); + + _sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1); + _sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[N * 7]), _sum3); + _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r0[N * 8]), _sum0); + _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r0[N * 9]), _sum1); + _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r0[N * 10]), _sum2); + _sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r0[N * 11]), _sum3); + _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r0[N * 12]), _sum0); + _sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r0[N * 13]), _sum1); + _sum2 = _mm256_fmadd_ps(_we, 
_mm256_set1_ps(r0[N * 14]), _sum2); + _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r0[N * 15]), _sum3); + + r0 += dilation_w; + kptr += 128; + } + } + } +#endif // __AVX512F__ + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 8) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + __m256 _w2 = _mm256_load_ps(kptr + 16); + __m256 _w3 = _mm256_load_ps(kptr + 24); + __m256 _w4 = _mm256_load_ps(kptr + 32); + __m256 _w5 = _mm256_load_ps(kptr + 40); + __m256 _w6 = _mm256_load_ps(kptr + 48); + __m256 _w7 = _mm256_load_ps(kptr + 56); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3); + + r0 += dilation_w * 8; + kptr += 64; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + __m256 _w2 = _mm256_load_ps(kptr + 16); + __m256 _w3 = _mm256_load_ps(kptr + 24); + __m256 _w4 = _mm256_load_ps(kptr + 32); + __m256 _w5 = _mm256_load_ps(kptr + 40); + __m256 _w6 = _mm256_load_ps(kptr + 48); + __m256 _w7 = _mm256_load_ps(kptr + 56); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + _sum0 = 
_mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r1[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r1[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r1[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r1[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + kptr += 64; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + __m256 _w2 = _mm256_load_ps(kptr + 16); + __m256 _w3 = _mm256_load_ps(kptr + 24); + __m256 _w4 = _mm256_load_ps(kptr + 32); + __m256 _w5 = _mm256_load_ps(kptr + 40); + __m256 _w6 = _mm256_load_ps(kptr + 48); + __m256 _w7 = _mm256_load_ps(kptr + 56); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r0[N * 7]), _sum3); + + r0 += dilation_w; + kptr += 64; + } + } + } + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 4) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + __m256 _w2 = _mm256_load_ps(kptr + 16); + __m256 _w3 = _mm256_load_ps(kptr + 24); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3); + + r0 += dilation_w * 4; + kptr += 32; + } + } + if 
(elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + __m256 _w2 = _mm256_load_ps(kptr + 16); + __m256 _w3 = _mm256_load_ps(kptr + 24); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3); + + r0 += dilation_w; + kptr += 32; + } + } + } + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + + _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1); + + r0 += dilation_w; + kptr += 16; + } + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _val = _mm256_set1_ps(r0[0]); + __m256 _w = _mm256_load_ps(kptr); + _sum0 = _mm256_comp_fmadd_ps(_val, _w, _sum0); + + r0 += dilation_w; + kptr += 8; + } + } + } + + _sum0 = _mm256_add_ps(_sum0, _sum1); + _sum2 = _mm256_add_ps(_sum2, _sum3); + _sum0 = _mm256_add_ps(_sum0, _sum2); + + _sum0 = activation_avx(_sum0, activation_type, activation_params); + + if (out_elempack == 8) + { + _mm256_store_ps(outptr, _sum0); + outptr += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(outptr, _mm256_extractf128_ps(_sum0, 0)); + _mm_store_ps(outptr + M, _mm256_extractf128_ps(_sum0, 1)); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[8]; + _mm256_storeu_ps(sum, _sum0); + + outptr[0] = sum[0]; + outptr[M] = sum[1]; + outptr[M * 2] = sum[2]; + outptr[M * 3] = sum[3]; + outptr[M * 4] = sum[4]; + outptr[M * 5] = sum[5]; + outptr[M * 6] = sum[6]; + outptr[M 
* 7] = sum[7]; + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 8; + nn_outh = (outh - remain_outh_start) / 4; +#else // __AVX__ + nn_outh = (outh - remain_outh_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __AVX__ + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 4; + + float* outptr = top_blob.row(p / out_elempack); + + for (int j = 0; j < outw; j++) + { + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm_loadu_ps(bias_data_ptr + p); + } + +#if __AVX512F__ + const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4); +#elif __AVX__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4); +#else + const float* kptr = weight_data_tm.channel(p / 4); +#endif + + int q = 0; +#if __AVX__ +#if __AVX512F__ + for (; q + 15 < inh; q += 16) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 16) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr + 4 * 0); + __m128 _w1 = _mm_load_ps(kptr + 4 * 1); + __m128 _w2 = _mm_load_ps(kptr + 4 * 2); + __m128 _w3 = _mm_load_ps(kptr + 4 * 3); + __m128 _w4 = _mm_load_ps(kptr + 4 * 4); + __m128 _w5 = _mm_load_ps(kptr + 4 * 5); + __m128 _w6 = _mm_load_ps(kptr + 4 * 6); + __m128 _w7 = _mm_load_ps(kptr + 4 * 7); + __m128 _w8 = _mm_load_ps(kptr + 4 * 8); + __m128 _w9 = _mm_load_ps(kptr + 4 * 9); + __m128 _wa = _mm_load_ps(kptr + 4 * 10); + __m128 _wb = _mm_load_ps(kptr + 4 * 11); + __m128 _wc = _mm_load_ps(kptr + 4 * 12); + __m128 _wd = _mm_load_ps(kptr + 4 * 13); + __m128 _we = _mm_load_ps(kptr + 4 * 14); + __m128 _wf = _mm_load_ps(kptr + 4 * 15); + + _sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = 
_mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0); + _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1); + _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2); + _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3); + _sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r0[8]), _sum0); + _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r0[9]), _sum1); + _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r0[10]), _sum2); + _sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r0[11]), _sum3); + _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r0[12]), _sum0); + _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r0[13]), _sum1); + _sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r0[14]), _sum2); + _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r0[15]), _sum3); + + r0 += dilation_w * 16; + kptr += 64; + } + } + if (elempack == 8) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr + 4 * 0); + __m128 _w1 = _mm_load_ps(kptr + 4 * 1); + __m128 _w2 = _mm_load_ps(kptr + 4 * 2); + __m128 _w3 = _mm_load_ps(kptr + 4 * 3); + __m128 _w4 = _mm_load_ps(kptr + 4 * 4); + __m128 _w5 = _mm_load_ps(kptr + 4 * 5); + __m128 _w6 = _mm_load_ps(kptr + 4 * 6); + __m128 _w7 = _mm_load_ps(kptr + 4 * 7); + __m128 _w8 = _mm_load_ps(kptr + 4 * 8); + __m128 _w9 = _mm_load_ps(kptr + 4 * 9); + __m128 _wa = _mm_load_ps(kptr + 4 * 10); + __m128 _wb = _mm_load_ps(kptr + 4 * 11); + __m128 _wc = _mm_load_ps(kptr + 4 * 12); + __m128 _wd = _mm_load_ps(kptr + 4 * 13); + __m128 _we = _mm_load_ps(kptr + 4 * 14); + __m128 _wf = _mm_load_ps(kptr + 4 * 15); + + _sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0); + _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1); + _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2); + _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3); + 
_sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r1[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r1[1]), _sum1); + _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r1[2]), _sum2); + _sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r1[3]), _sum3); + _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r1[4]), _sum0); + _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r1[5]), _sum1); + _sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r1[6]), _sum2); + _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r1[7]), _sum3); + + r0 += dilation_w * 8; + r1 += dilation_w * 8; + kptr += 64; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + const float* r2 = r0 + N * 2; + const float* r3 = r0 + N * 3; + + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr + 4 * 0); + __m128 _w1 = _mm_load_ps(kptr + 4 * 1); + __m128 _w2 = _mm_load_ps(kptr + 4 * 2); + __m128 _w3 = _mm_load_ps(kptr + 4 * 3); + __m128 _w4 = _mm_load_ps(kptr + 4 * 4); + __m128 _w5 = _mm_load_ps(kptr + 4 * 5); + __m128 _w6 = _mm_load_ps(kptr + 4 * 6); + __m128 _w7 = _mm_load_ps(kptr + 4 * 7); + __m128 _w8 = _mm_load_ps(kptr + 4 * 8); + __m128 _w9 = _mm_load_ps(kptr + 4 * 9); + __m128 _wa = _mm_load_ps(kptr + 4 * 10); + __m128 _wb = _mm_load_ps(kptr + 4 * 11); + __m128 _wc = _mm_load_ps(kptr + 4 * 12); + __m128 _wd = _mm_load_ps(kptr + 4 * 13); + __m128 _we = _mm_load_ps(kptr + 4 * 14); + __m128 _wf = _mm_load_ps(kptr + 4 * 15); + + _sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r1[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r1[1]), _sum1); + _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r1[2]), _sum2); + _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r1[3]), _sum3); + _sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r2[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r2[1]), _sum1); + _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r2[2]), _sum2); + _sum3 = 
_mm_fmadd_ps(_wb, _mm_set1_ps(r2[3]), _sum3); + _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r3[0]), _sum0); + _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r3[1]), _sum1); + _sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r3[2]), _sum2); + _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r3[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + r2 += dilation_w * 4; + r3 += dilation_w * 4; + kptr += 64; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr + 4 * 0); + __m128 _w1 = _mm_load_ps(kptr + 4 * 1); + __m128 _w2 = _mm_load_ps(kptr + 4 * 2); + __m128 _w3 = _mm_load_ps(kptr + 4 * 3); + __m128 _w4 = _mm_load_ps(kptr + 4 * 4); + __m128 _w5 = _mm_load_ps(kptr + 4 * 5); + __m128 _w6 = _mm_load_ps(kptr + 4 * 6); + __m128 _w7 = _mm_load_ps(kptr + 4 * 7); + __m128 _w8 = _mm_load_ps(kptr + 4 * 8); + __m128 _w9 = _mm_load_ps(kptr + 4 * 9); + __m128 _wa = _mm_load_ps(kptr + 4 * 10); + __m128 _wb = _mm_load_ps(kptr + 4 * 11); + __m128 _wc = _mm_load_ps(kptr + 4 * 12); + __m128 _wd = _mm_load_ps(kptr + 4 * 13); + __m128 _we = _mm_load_ps(kptr + 4 * 14); + __m128 _wf = _mm_load_ps(kptr + 4 * 15); + + _sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1); + _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[N * 7]), _sum3); + _sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r0[N * 8]), _sum0); + _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r0[N * 9]), _sum1); + _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r0[N * 10]), _sum2); + _sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r0[N * 11]), _sum3); + _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r0[N * 12]), _sum0); + _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r0[N * 13]), _sum1); + _sum2 = _mm_fmadd_ps(_we, 
_mm_set1_ps(r0[N * 14]), _sum2); + _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r0[N * 15]), _sum3); + + r0 += dilation_w; + kptr += 64; + } + } + } +#endif // __AVX512F__ + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 8) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + __m128 _w2 = _mm_load_ps(kptr + 8); + __m128 _w3 = _mm_load_ps(kptr + 12); + __m128 _w4 = _mm_load_ps(kptr + 16); + __m128 _w5 = _mm_load_ps(kptr + 20); + __m128 _w6 = _mm_load_ps(kptr + 24); + __m128 _w7 = _mm_load_ps(kptr + 28); + + _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + _sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3); + + r0 += dilation_w * 8; + kptr += 32; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + __m128 _w2 = _mm_load_ps(kptr + 8); + __m128 _w3 = _mm_load_ps(kptr + 12); + __m128 _w4 = _mm_load_ps(kptr + 16); + __m128 _w5 = _mm_load_ps(kptr + 20); + __m128 _w6 = _mm_load_ps(kptr + 24); + __m128 _w7 = _mm_load_ps(kptr + 28); + + _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + _sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r1[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r1[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w6, 
_mm_set1_ps(r1[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r1[3]), _sum3); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + kptr += 32; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + __m128 _w2 = _mm_load_ps(kptr + 8); + __m128 _w3 = _mm_load_ps(kptr + 12); + __m128 _w4 = _mm_load_ps(kptr + 16); + __m128 _w5 = _mm_load_ps(kptr + 20); + __m128 _w6 = _mm_load_ps(kptr + 24); + __m128 _w7 = _mm_load_ps(kptr + 28); + + _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3); + _sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r0[N * 4]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r0[N * 5]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r0[N * 6]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r0[N * 7]), _sum3); + + r0 += dilation_w; + kptr += 32; + } + } + } +#endif // __AVX__ + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 4) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + __m128 _w2 = _mm_load_ps(kptr + 8); + __m128 _w3 = _mm_load_ps(kptr + 12); + + _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3); + + r0 += dilation_w * 4; + kptr += 16; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + __m128 _w2 = _mm_load_ps(kptr + 8); + __m128 _w3 = _mm_load_ps(kptr + 12); + + _sum0 = _mm_comp_fmadd_ps(_w0, 
_mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3); + + r0 += dilation_w; + kptr += 16; + } + } + } + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + + _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1); + + r0 += dilation_w; + kptr += 8; + } + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _val = _mm_set1_ps(r0[0]); + __m128 _w = _mm_load_ps(kptr); + _sum0 = _mm_comp_fmadd_ps(_val, _w, _sum0); + + r0 += dilation_w; + kptr += 4; + } + } + } + + _sum0 = _mm_add_ps(_sum0, _sum1); + _sum2 = _mm_add_ps(_sum2, _sum3); + _sum0 = _mm_add_ps(_sum0, _sum2); + + _sum0 = activation_sse(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + _mm_storeu_ps(outptr, _sum0); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[4]; + _mm_storeu_ps(sum, _sum0); + + outptr[0] = sum[0]; + outptr[M] = sum[1]; + outptr[M * 2] = sum[2]; + outptr[M * 3] = sum[3]; + outptr += 1; + } + } + } + remain_outh_start += nn_outh * 4; + nn_outh = (outh - remain_outh_start) / 2; +#else // __SSE2__ + nn_outh = (outh - remain_outh_start) / 2; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __SSE2__ + for (int pp = 0; pp < nn_outh; pp++) + { + const int p = remain_outh_start + pp * 2; + + float* outptr0 = top_blob.row(p); + float* outptr1 = top_blob.row(p + 1); + + for (int j = 0; j < outw; j++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; + } 
+ +#if __AVX512F__ + const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2); +#elif __AVX__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2); +#elif __SSE2__ + const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2); +#else + const float* kptr = weight_data_tm.channel(p / 2); +#endif + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum0_avx512 = _mm512_setzero_ps(); + __m512 _sum1_avx512 = _mm512_setzero_ps(); + for (; q + 15 < inh; q += 16) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 16) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_load_ps(r0); + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + _sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512); + + r0 += dilation_w * 16; + kptr += 32; + } + } + if (elempack == 8) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_load_ps(r0)), _mm256_load_ps(r1), 1); + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + _sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512); + + r0 += dilation_w * 8; + r1 += dilation_w * 8; + kptr += 32; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + const float* r2 = r0 + N * 2; + const float* r3 = r0 + N * 3; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r2)), _mm_load_ps(r3), 1), 1); + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + _sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512); + _sum1_avx512 = 
_mm512_fmadd_ps(_r0, _w1, _sum1_avx512); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + r2 += dilation_w * 4; + r3 += dilation_w * 4; + kptr += 32; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_set_ps(r0[N * 15], r0[N * 14], r0[N * 13], r0[N * 12], r0[N * 11], r0[N * 10], r0[N * 9], r0[N * 8], r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _w1 = _mm512_load_ps(kptr + 16); + _sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512); + + r0 += dilation_w; + kptr += 32; + } + } + } + sum0 += _mm512_comp_reduce_add_ps(_sum0_avx512); + sum1 += _mm512_comp_reduce_add_ps(_sum1_avx512); +#endif // __AVX512F__ + __m256 _sum0_avx = _mm256_setzero_ps(); + __m256 _sum1_avx = _mm256_setzero_ps(); + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 8) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_load_ps(r0); + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx); + + r0 += dilation_w * 8; + kptr += 16; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1); + __m256 _w0 = _mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + kptr += 16; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_set_ps(r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m256 _w0 = 
_mm256_load_ps(kptr); + __m256 _w1 = _mm256_load_ps(kptr + 8); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx); + + r0 += dilation_w; + kptr += 16; + } + } + } + sum0 += _mm256_reduce_add_ps(_sum0_avx); + sum1 += _mm256_reduce_add_ps(_sum1_avx); +#endif // __AVX__ + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 4) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _r0 = _mm_load_ps(r0); + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + _sum0 = _mm_comp_fmadd_ps(_r0, _w0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_r0, _w1, _sum1); + + r0 += dilation_w * 4; + kptr += 8; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _r0 = _mm_set_ps(r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m128 _w0 = _mm_load_ps(kptr); + __m128 _w1 = _mm_load_ps(kptr + 4); + _sum0 = _mm_comp_fmadd_ps(_r0, _w0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_r0, _w1, _sum1); + + r0 += dilation_w; + kptr += 8; + } + } + } + sum0 += _mm_reduce_add_ps(_sum0); + sum1 += _mm_reduce_add_ps(_sum1); +#endif // __SSE2__ + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + sum0 += r0[0] * kptr[0]; + sum1 += r0[0] * kptr[1]; + sum0 += r0[N] * kptr[2]; + sum1 += r0[N] * kptr[3]; + + r0 += dilation_w; + kptr += 4; + } + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + float val = r0[0]; + sum0 += val * kptr[0]; + sum1 += val * kptr[1]; + + r0 += dilation_w; + kptr += 2; + } + } + } + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = sum0; + 
outptr1[0] = sum1; + outptr0 += 1; + outptr1 += 1; + } + } + remain_outh_start += nn_outh * 2; + for (int p = remain_outh_start; p < outh; p++) + { + float* outptr = top_blob.row(p); + + for (int j = 0; j < outw; j++) + { + float sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + +#if __AVX512F__ + const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __AVX__ + const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __SSE2__ + const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2 + p % 2); +#else + const float* kptr = weight_data_tm.channel(p / 2 + p % 2); +#endif + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum_avx512 = _mm512_setzero_ps(); + for (; q + 15 < inh; q += 16) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 16) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_load_ps(r0); + __m512 _w = _mm512_load_ps(kptr); + _sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512); + + r0 += dilation_w * 16; + kptr += 16; + } + } + if (elempack == 8) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_load_ps(r0)), _mm256_load_ps(r1), 1); + __m512 _w = _mm512_load_ps(kptr); + _sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512); + + r0 += dilation_w * 8; + r1 += dilation_w * 8; + kptr += 16; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + const float* r2 = r0 + N * 2; + const float* r3 = r0 + N * 3; + + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r2)), _mm_load_ps(r3), 1), 1); + __m512 _w = _mm512_load_ps(kptr); + _sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512); + + r0 
+= dilation_w * 4; + r1 += dilation_w * 4; + r2 += dilation_w * 4; + r3 += dilation_w * 4; + kptr += 16; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m512 _r0 = _mm512_set_ps(r0[N * 15], r0[N * 14], r0[N * 13], r0[N * 12], r0[N * 11], r0[N * 10], r0[N * 9], r0[N * 8], r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m512 _w = _mm512_load_ps(kptr); + _sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512); + + r0 += dilation_w; + kptr += 16; + } + } + } + sum += _mm512_comp_reduce_add_ps(_sum_avx512); +#endif // __AVX512F__ + __m256 _sum_avx = _mm256_setzero_ps(); + for (; q + 7 < inh; q += 8) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 8) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_load_ps(r0); + __m256 _w = _mm256_load_ps(kptr); + _sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx); + + r0 += dilation_w * 8; + kptr += 8; + } + } + if (elempack == 4) + { + const float* r1 = r0 + N; + + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1); + __m256 _w = _mm256_load_ps(kptr); + _sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx); + + r0 += dilation_w * 4; + r1 += dilation_w * 4; + kptr += 8; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m256 _r0 = _mm256_set_ps(r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m256 _w = _mm256_load_ps(kptr); + _sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx); + + r0 += dilation_w; + kptr += 8; + } + } + } + sum += _mm256_reduce_add_ps(_sum_avx); +#endif // __AVX__ + __m128 _sum = _mm_setzero_ps(); + for (; q + 3 < inh; q += 4) + { + const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack; + + if (elempack == 4) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _r0 = _mm_load_ps(r0); + __m128 _w = _mm_load_ps(kptr); + _sum 
= _mm_comp_fmadd_ps(_r0, _w, _sum); + + r0 += dilation_w * 4; + kptr += 4; + } + } + if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + __m128 _r0 = _mm_set_ps(r0[N * 3], r0[N * 2], r0[N], r0[0]); + __m128 _w = _mm_load_ps(kptr); + _sum = _mm_comp_fmadd_ps(_r0, _w, _sum); + + r0 += dilation_w; + kptr += 4; + } + } + } + sum += _mm_reduce_add_ps(_sum); +#endif // __SSE2__ + for (; q + 1 < inh; q += 2) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + sum += r0[0] * kptr[0]; + sum += r0[N] * kptr[1]; + + r0 += dilation_w; + kptr += 2; + } + } + } + for (; q < inh; q++) + { + const float* r0 = bottom_blob.row(q) + j * stride_w; + + // if (elempack == 1) + { + for (int k = 0; k < kernel_w; k++) + { + float val = r0[0]; + sum += val * kptr[0]; + + r0 += dilation_w; + kptr += 1; + } + } + } + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = sum; + outptr += 1; + } + } +} diff --git a/src/layer/x86/convolution1d_x86.cpp b/src/layer/x86/convolution1d_x86.cpp index 950226e6b..e7df16b83 100644 --- a/src/layer/x86/convolution1d_x86.cpp +++ b/src/layer/x86/convolution1d_x86.cpp @@ -25,6 +25,8 @@ namespace ncnn { +#include "convolution1d_packed.h" + Convolution1D_x86::Convolution1D_x86() { #if __SSE2__ @@ -32,59 +34,14 @@ Convolution1D_x86::Convolution1D_x86() #endif // __SSE2__ } -int Convolution1D_x86::create_pipeline(const Option& opt) +int Convolution1D_x86::create_pipeline(const Option& /*opt*/) { if (dynamic_weight) return 0; int num_input = weight_data_size / kernel_w / num_output; - int elempack = 1; - int out_elempack = 1; - -#if __SSE2__ - if (opt.use_packing_layout) - { -#if __AVX__ - elempack = num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; - out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; -#else - elempack = num_input % 4 == 0 ? 4 : 1; - out_elempack = num_output % 4 == 0 ? 
4 : 1; -#endif - } -#endif // __SSE2__ - - // src = kw-inch-outch - // dst = pb-pa-kw-inch/pa-outch/pb - { - Mat weight_data_r2 = weight_data.reshape(kernel_w, num_input, num_output); - - weight_data_packed.create(kernel_w, num_input / elempack, num_output / out_elempack, (size_t)4u * elempack * out_elempack, elempack * out_elempack); - - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) - { - float* g00 = weight_data_packed.channel(q / out_elempack); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < kernel_w; k++) - { - for (int i = 0; i < elempack; i++) - { - for (int j = 0; j < out_elempack; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - - g00[0] = k00[k]; - - g00++; - } - } - } - } - } - } + convolution1d_transform_kernel_packed(weight_data, weight_data_tm, num_input, num_output, kernel_w); return 0; } @@ -97,25 +54,9 @@ int Convolution1D_x86::destroy_pipeline(const Option& /*opt*/) int Convolution1D_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int w = bottom_blob.w; - int h = bottom_blob.h; size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; -#if __AVX512F__ - if (elempack == 16) - { - Mat tmp; - convert_packing(bottom_blob, tmp, 8, opt); - - Mat tmpout; - forward(tmp, tmpout, opt); - - convert_packing(tmpout, top_blob, 16, opt); - - return 0; - } -#endif // __AVX512F__ - const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; Mat bottom_blob_bordered; @@ -124,13 +65,14 @@ int Convolution1D_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opti return -100; w = bottom_blob_bordered.w; - h = bottom_blob_bordered.h; int out_elempack = 1; #if __SSE2__ if (opt.use_packing_layout) { -#if __AVX__ +#if __AVX512F__ + out_elempack = num_output % 16 == 0 ? 16 : num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; +#elif __AVX__ out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 
4 : 1; #else out_elempack = num_output % 4 == 0 ? 4 : 1; @@ -146,457 +88,7 @@ int Convolution1D_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opti if (top_blob.empty()) return -100; -#if __SSE2__ -#if __AVX__ - if (elempack == 8 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - __m256 _sum = _mm256_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm256_loadu_ps(((const float*)bias_data) + p * 8); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) - { - __m256 _val0 = _mm256_broadcast_ss(sptr); - __m256 _val1 = _mm256_broadcast_ss(sptr + 1); - __m256 _val2 = _mm256_broadcast_ss(sptr + 2); - __m256 _val3 = _mm256_broadcast_ss(sptr + 3); - __m256 _val4 = _mm256_broadcast_ss(sptr + 4); - __m256 _val5 = _mm256_broadcast_ss(sptr + 5); - __m256 _val6 = _mm256_broadcast_ss(sptr + 6); - __m256 _val7 = _mm256_broadcast_ss(sptr + 7); - - __m256 _w0 = _mm256_loadu_ps(kptr); - __m256 _w1 = _mm256_loadu_ps(kptr + 8); - __m256 _w2 = _mm256_loadu_ps(kptr + 16); - __m256 _w3 = _mm256_loadu_ps(kptr + 24); - __m256 _w4 = _mm256_loadu_ps(kptr + 32); - __m256 _w5 = _mm256_loadu_ps(kptr + 40); - __m256 _w6 = _mm256_loadu_ps(kptr + 48); - __m256 _w7 = _mm256_loadu_ps(kptr + 56); - - _mm256_comp_fmadd_ps8(_sum, - _val0, _val1, _val2, _val3, _val4, _val5, _val6, _val7, - _w0, _w1, _w2, _w3, _w4, _w5, _w6, _w7); - - sptr += dilation_w * 8; - kptr += 64; - } - } - - _sum = activation_avx(_sum, activation_type, activation_params); - - _mm256_storeu_ps(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 1 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int 
j = 0; j < outw; j++) - { - __m256 _sum = _mm256_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm256_loadu_ps(((const float*)bias_data) + p * 8); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - __m256 _val = _mm256_set1_ps(sptr[0]); - __m256 _w = _mm256_loadu_ps(kptr); - _sum = _mm256_comp_fmadd_ps(_val, _w, _sum); - - sptr += dilation_w; - kptr += 8; - } - } - - _sum = activation_avx(_sum, activation_type, activation_params); - - _mm256_storeu_ps(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 4 && out_elempack == 8) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - __m256 _sum = _mm256_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm256_loadu_ps((const float*)bias_data + p * 8); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - __m256 _val0 = _mm256_broadcast_ss(sptr); - __m256 _val1 = _mm256_broadcast_ss(sptr + 1); - __m256 _val2 = _mm256_broadcast_ss(sptr + 2); - __m256 _val3 = _mm256_broadcast_ss(sptr + 3); - - __m256 _w0 = _mm256_loadu_ps(kptr); - _sum = _mm256_comp_fmadd_ps(_val0, _w0, _sum); - __m256 _w1 = _mm256_loadu_ps(kptr + 8); - _sum = _mm256_comp_fmadd_ps(_val1, _w1, _sum); - __m256 _w2 = _mm256_loadu_ps(kptr + 16); - _sum = _mm256_comp_fmadd_ps(_val2, _w2, _sum); - __m256 _w3 = _mm256_loadu_ps(kptr + 24); - _sum = _mm256_comp_fmadd_ps(_val3, _w3, _sum); - - sptr += dilation_w * 4; - kptr += 32; - } - } - - _sum = activation_avx(_sum, activation_type, activation_params); - - _mm256_storeu_ps(outptr, _sum); - outptr += 8; - } - } - } - } - - if (elempack == 8 && out_elempack == 1) - { - { - 
#pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const float* kptr = weight_data_packed.channel(p); - - __m256 _sum8 = _mm256_set1_ps(0); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) // 29.23 - { - __m256 _val = _mm256_loadu_ps(sptr); - __m256 _w = _mm256_loadu_ps(kptr); - __m256 _s8 = _mm256_mul_ps(_val, _w); - _sum8 = _mm256_add_ps(_sum8, _s8); - - sptr += dilation_w * 8; - kptr += 8; - } - } - sum += _mm256_reduce_add_ps(_sum8); // dot - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } - - if (elempack == 8 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - __m128 _sum = _mm_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm_loadu_ps((const float*)bias_data + p * 4); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 8; - - for (int k = 0; k < kernel_w; k++) - { - __m128 _val0 = _mm_broadcast_ss(sptr); - __m128 _val1 = _mm_broadcast_ss(sptr + 1); - __m128 _val2 = _mm_broadcast_ss(sptr + 2); - __m128 _val3 = _mm_broadcast_ss(sptr + 3); - __m128 _val4 = _mm_broadcast_ss(sptr + 4); - __m128 _val5 = _mm_broadcast_ss(sptr + 5); - __m128 _val6 = _mm_broadcast_ss(sptr + 6); - __m128 _val7 = _mm_broadcast_ss(sptr + 7); - - __m128 _w0 = _mm_loadu_ps(kptr); - _sum = _mm_comp_fmadd_ps(_val0, _w0, _sum); - __m128 _w1 = _mm_loadu_ps(kptr + 4); - _sum = _mm_comp_fmadd_ps(_val1, _w1, _sum); - __m128 _w2 = _mm_loadu_ps(kptr + 8); - _sum = _mm_comp_fmadd_ps(_val2, _w2, _sum); - __m128 _w3 = 
_mm_loadu_ps(kptr + 12); - _sum = _mm_comp_fmadd_ps(_val3, _w3, _sum); - __m128 _w4 = _mm_loadu_ps(kptr + 16); - _sum = _mm_comp_fmadd_ps(_val4, _w4, _sum); - __m128 _w5 = _mm_loadu_ps(kptr + 20); - _sum = _mm_comp_fmadd_ps(_val5, _w5, _sum); - __m128 _w6 = _mm_loadu_ps(kptr + 24); - _sum = _mm_comp_fmadd_ps(_val6, _w6, _sum); - __m128 _w7 = _mm_loadu_ps(kptr + 28); - _sum = _mm_comp_fmadd_ps(_val7, _w7, _sum); - - sptr += dilation_w * 8; - kptr += 32; - } - } - - _sum = activation_sse(_sum, activation_type, activation_params); - - _mm_storeu_ps(outptr, _sum); - outptr += 4; - } - } - } - } -#endif - - if (elempack == 4 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - __m128 _sum = _mm_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm_loadu_ps((const float*)bias_data + p * 4); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - __m128 _val0 = _mm_set1_ps(sptr[0]); - __m128 _val1 = _mm_set1_ps(sptr[1]); - __m128 _val2 = _mm_set1_ps(sptr[2]); - __m128 _val3 = _mm_set1_ps(sptr[3]); - - __m128 _w0 = _mm_loadu_ps(kptr); - _sum = _mm_add_ps(_mm_mul_ps(_val0, _w0), _sum); - __m128 _w1 = _mm_loadu_ps(kptr + 4); - _sum = _mm_add_ps(_mm_mul_ps(_val1, _w1), _sum); - __m128 _w2 = _mm_loadu_ps(kptr + 8); - _sum = _mm_add_ps(_mm_mul_ps(_val2, _w2), _sum); - __m128 _w3 = _mm_loadu_ps(kptr + 12); - _sum = _mm_add_ps(_mm_mul_ps(_val3, _w3), _sum); - - sptr += dilation_w * 4; - kptr += 16; - } - } - - _sum = activation_sse(_sum, activation_type, activation_params); - - _mm_storeu_ps(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 1 && out_elempack == 4) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* 
outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - __m128 _sum = _mm_set1_ps(0.f); - - if (bias_term) - { - _sum = _mm_loadu_ps((const float*)bias_data + p * 4); - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - __m128 _val = _mm_set1_ps(sptr[0]); - __m128 _w = _mm_loadu_ps(kptr); - _sum = _mm_add_ps(_mm_mul_ps(_val, _w), _sum); - - sptr += dilation_w; - kptr += 4; - } - } - - _sum = activation_sse(_sum, activation_type, activation_params); - - _mm_storeu_ps(outptr, _sum); - outptr += 4; - } - } - } - } - - if (elempack == 4 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const float* kptr = weight_data_packed.channel(p); - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w * 4; - - for (int k = 0; k < kernel_w; k++) - { - __m128 _val = _mm_loadu_ps(sptr); - __m128 _w = _mm_loadu_ps(kptr); - __m128 _s4 = _mm_mul_ps(_val, _w); - sum += _mm_reduce_add_ps(_s4); // dot - - sptr += dilation_w * 4; - kptr += 4; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } -#endif // __SSE2__ - - if (elempack == 1 && out_elempack == 1) - { - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outh; p++) - { - float* outptr = top_blob.row(p); - - for (int j = 0; j < outw; j++) - { - float sum = 0.f; - - if (bias_term) - { - sum = bias_data[p]; - } - - const float* kptr = (const float*)weight_data + kernel_w * h * p; - - for (int q = 0; q < h; q++) - { - const float* sptr = bottom_blob_bordered.row(q) + j * stride_w; - - for (int k = 0; k < kernel_w; k++) - { - 
float val = sptr[0]; - float wt = kptr[0]; - sum += val * wt; - - sptr += dilation_w; - kptr += 1; - } - } - - sum = activation_ss(sum, activation_type, activation_params); - - outptr[j] = sum; - } - } - } - } + convolution1d_packed(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, dilation_w, stride_w, activation_type, activation_params, opt); return 0; } diff --git a/src/layer/x86/convolution1d_x86.h b/src/layer/x86/convolution1d_x86.h index 33665025d..ec1782b70 100644 --- a/src/layer/x86/convolution1d_x86.h +++ b/src/layer/x86/convolution1d_x86.h @@ -32,7 +32,7 @@ public: virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; public: - Mat weight_data_packed; + Mat weight_data_tm; }; } // namespace ncnn diff --git a/src/layer/x86/convolution_packed.h b/src/layer/x86/convolution_packed.h index cfc0a1425..1b797dc46 100644 --- a/src/layer/x86/convolution_packed.h +++ b/src/layer/x86/convolution_packed.h @@ -1188,22 +1188,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - const float* r8 = r0 + N * 8; - const float* r9 = r0 + N * 9; - const float* ra = r0 + N * 10; - const float* rb = r0 + N * 11; - const float* rc = r0 + N * 12; - const float* rd = r0 + N * 13; - const float* re = r0 + N * 14; - const float* rf = r0 + N * 15; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1226,21 +1210,21 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m512 _wf = _mm512_load_ps(kptr + 16 * 15); _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r1[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r2[sok]), _sum2); - 
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r3[sok]), _sum3); - _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r4[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r5[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r6[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r7[sok]), _sum3); - _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r8[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r9[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(ra[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(rb[sok]), _sum3); - _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(rc[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(rd[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(re[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(rf[sok]), _sum3); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[sok + N * 7]), _sum3); + _sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r0[sok + N * 8]), _sum0); + _sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r0[sok + N * 9]), _sum1); + _sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r0[sok + N * 10]), _sum2); + _sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r0[sok + N * 11]), _sum3); + _sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r0[sok + N * 12]), _sum0); + _sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r0[sok + N * 13]), _sum1); + _sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r0[sok + N * 14]), _sum2); + _sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r0[sok + N * 15]), _sum3); kptr += 256; } @@ -1309,14 +1293,6 @@ static void convolution_packed(const Mat& 
bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1331,13 +1307,13 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m512 _w7 = _mm512_load_ps(kptr + 16 * 7); _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r1[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r2[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r3[sok]), _sum3); - _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r4[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r5[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r6[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r7[sok]), _sum3); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[sok + N * 7]), _sum3); kptr += 128; } @@ -1368,10 +1344,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1382,9 +1354,9 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m512 _w3 = _mm512_load_ps(kptr + 48); _sum0 = _mm512_fmadd_ps(_w0, 
_mm512_set1_ps(r0[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r1[sok]), _sum1); - _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r2[sok]), _sum2); - _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r3[sok]), _sum3); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[sok + N * 3]), _sum3); kptr += 64; } @@ -1396,8 +1368,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& // if (elempack == 1) { - const float* r1 = r0 + N; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1406,7 +1376,7 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m512 _w1 = _mm512_load_ps(kptr + 16); _sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[sok]), _sum0); - _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r1[sok]), _sum1); + _sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[sok + N]), _sum1); kptr += 32; } @@ -1659,22 +1629,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - const float* r8 = r0 + N * 8; - const float* r9 = r0 + N * 9; - const float* ra = r0 + N * 10; - const float* rb = r0 + N * 11; - const float* rc = r0 + N * 12; - const float* rd = r0 + N * 13; - const float* re = r0 + N * 14; - const float* rf = r0 + N * 15; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1697,21 +1651,21 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m256 _wf = _mm256_load_ps(kptr + 8 * 15); _sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[sok]), _sum0); - _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r1[sok]), _sum1); - _sum2 = 
_mm256_fmadd_ps(_w2, _mm256_set1_ps(r2[sok]), _sum2); - _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r3[sok]), _sum3); - _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r4[sok]), _sum0); - _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r5[sok]), _sum1); - _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r6[sok]), _sum2); - _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r7[sok]), _sum3); - _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r8[sok]), _sum0); - _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r9[sok]), _sum1); - _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(ra[sok]), _sum2); - _sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(rb[sok]), _sum3); - _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(rc[sok]), _sum0); - _sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(rd[sok]), _sum1); - _sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(re[sok]), _sum2); - _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(rf[sok]), _sum3); + _sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[sok + N * 7]), _sum3); + _sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r0[sok + N * 8]), _sum0); + _sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r0[sok + N * 9]), _sum1); + _sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r0[sok + N * 10]), _sum2); + _sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r0[sok + N * 11]), _sum3); + _sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r0[sok + N * 12]), _sum0); + _sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r0[sok + N * 13]), _sum1); + _sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r0[sok + N * 14]), _sum2); + _sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r0[sok + N * 15]), _sum3); kptr += 128; } @@ -1781,14 
+1735,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1803,13 +1749,13 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m256 _w7 = _mm256_load_ps(kptr + 56); _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[sok]), _sum0); - _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r1[sok]), _sum1); - _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r2[sok]), _sum2); - _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r3[sok]), _sum3); - _sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r4[sok]), _sum0); - _sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r5[sok]), _sum1); - _sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r6[sok]), _sum2); - _sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r7[sok]), _sum3); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r0[sok + N * 7]), _sum3); kptr += 64; } @@ -1840,10 +1786,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1854,9 +1796,9 @@ static void convolution_packed(const Mat& 
bottom_blob, Mat& top_blob, const Mat& __m256 _w3 = _mm256_load_ps(kptr + 24); _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[sok]), _sum0); - _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r1[sok]), _sum1); - _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r2[sok]), _sum2); - _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r3[sok]), _sum3); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[sok + N * 3]), _sum3); kptr += 32; } @@ -1868,8 +1810,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& // if (elempack == 1) { - const float* r1 = r0 + N; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -1878,7 +1818,7 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m256 _w1 = _mm256_load_ps(kptr + 8); _sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[sok]), _sum0); - _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r1[sok]), _sum1); + _sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[sok + N]), _sum1); kptr += 16; } @@ -2118,22 +2058,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - const float* r8 = r0 + N * 8; - const float* r9 = r0 + N * 9; - const float* ra = r0 + N * 10; - const float* rb = r0 + N * 11; - const float* rc = r0 + N * 12; - const float* rd = r0 + N * 13; - const float* re = r0 + N * 14; - const float* rf = r0 + N * 15; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -2156,21 +2080,21 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m128 _wf = _mm_load_ps(kptr + 4 * 15); 
_sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[sok]), _sum0); - _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r1[sok]), _sum1); - _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r2[sok]), _sum2); - _sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r3[sok]), _sum3); - _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r4[sok]), _sum0); - _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r5[sok]), _sum1); - _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r6[sok]), _sum2); - _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r7[sok]), _sum3); - _sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r8[sok]), _sum0); - _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r9[sok]), _sum1); - _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(ra[sok]), _sum2); - _sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(rb[sok]), _sum3); - _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(rc[sok]), _sum0); - _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(rd[sok]), _sum1); - _sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(re[sok]), _sum2); - _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(rf[sok]), _sum3); + _sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[sok + N * 7]), _sum3); + _sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r0[sok + N * 8]), _sum0); + _sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r0[sok + N * 9]), _sum1); + _sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r0[sok + N * 10]), _sum2); + _sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r0[sok + N * 11]), _sum3); + _sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r0[sok + N * 12]), _sum0); + _sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r0[sok + N * 13]), _sum1); + _sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r0[sok + N * 14]), _sum2); + _sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r0[sok + N * 15]), _sum3); kptr += 64; } @@ -2240,14 +2164,6 @@ static void convolution_packed(const Mat& 
bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -2262,13 +2178,13 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m128 _w7 = _mm_load_ps(kptr + 28); _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[sok]), _sum0); - _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r1[sok]), _sum1); - _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r2[sok]), _sum2); - _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r3[sok]), _sum3); - _sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r4[sok]), _sum0); - _sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r5[sok]), _sum1); - _sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r6[sok]), _sum2); - _sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r7[sok]), _sum3); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[sok + N * 3]), _sum3); + _sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r0[sok + N * 4]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r0[sok + N * 5]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r0[sok + N * 6]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r0[sok + N * 7]), _sum3); kptr += 32; } @@ -2300,10 +2216,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -2314,9 +2226,9 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m128 _w3 = _mm_load_ps(kptr + 12); _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[sok]), _sum0); - _sum1 = 
_mm_comp_fmadd_ps(_w1, _mm_set1_ps(r1[sok]), _sum1); - _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r2[sok]), _sum2); - _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r3[sok]), _sum3); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[sok + N]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[sok + N * 2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[sok + N * 3]), _sum3); kptr += 16; } @@ -2328,8 +2240,6 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& // if (elempack == 1) { - const float* r1 = r0 + N; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; @@ -2338,7 +2248,7 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& __m128 _w1 = _mm_load_ps(kptr + 4); _sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[sok]), _sum0); - _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r1[sok]), _sum1); + _sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[sok + N]), _sum1); kptr += 8; } @@ -2482,26 +2392,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - const float* r8 = r0 + N * 8; - const float* r9 = r0 + N * 9; - const float* ra = r0 + N * 10; - const float* rb = r0 + N * 11; - const float* rc = r0 + N * 12; - const float* rd = r0 + N * 13; - const float* re = r0 + N * 14; - const float* rf = r0 + N * 15; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m512 _r0 = _mm512_set_ps(rf[sok], re[sok], rd[sok], rc[sok], rb[sok], ra[sok], r9[sok], r8[sok], r7[sok], r6[sok], r5[sok], r4[sok], r3[sok], r2[sok], r1[sok], r0[sok]); + __m512 _r0 = _mm512_set_ps(r0[sok + N * 15], r0[sok + N * 14], r0[sok + N * 13], r0[sok + N * 12], r0[sok + N * 11], r0[sok + N * 10], r0[sok + N * 9], r0[sok + N * 8], r0[sok + N * 
7], r0[sok + N * 6], r0[sok + N * 5], r0[sok + N * 4], r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m512 _w0 = _mm512_load_ps(kptr); __m512 _w1 = _mm512_load_ps(kptr + 16); _sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512); @@ -2552,18 +2446,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m256 _r0 = _mm256_set_ps(r7[sok], r6[sok], r5[sok], r4[sok], r3[sok], r2[sok], r1[sok], r0[sok]); + __m256 _r0 = _mm256_set_ps(r0[sok + N * 7], r0[sok + N * 6], r0[sok + N * 5], r0[sok + N * 4], r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m256 _w0 = _mm256_load_ps(kptr); __m256 _w1 = _mm256_load_ps(kptr + 8); _sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx); @@ -2598,14 +2484,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m128 _r0 = _mm_set_ps(r3[sok], r2[sok], r1[sok], r0[sok]); + __m128 _r0 = _mm_set_ps(r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m128 _w0 = _mm_load_ps(kptr); __m128 _w1 = _mm_load_ps(kptr + 4); _sum0 = _mm_comp_fmadd_ps(_r0, _w0, _sum0); @@ -2624,16 +2506,14 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& // if (elempack == 1) { - const float* r1 = r0 + N; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; sum0 += r0[sok] * kptr[0]; sum1 += r0[sok] * kptr[1]; - sum0 += r1[sok] * kptr[2]; - sum1 += r1[sok] * kptr[3]; + sum0 += r0[sok + N] * kptr[2]; + sum1 += r0[sok + N] * kptr[3]; kptr += 4; } @@ 
-2745,26 +2625,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - const float* r8 = r0 + N * 8; - const float* r9 = r0 + N * 9; - const float* ra = r0 + N * 10; - const float* rb = r0 + N * 11; - const float* rc = r0 + N * 12; - const float* rd = r0 + N * 13; - const float* re = r0 + N * 14; - const float* rf = r0 + N * 15; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m512 _r0 = _mm512_set_ps(rf[sok], re[sok], rd[sok], rc[sok], rb[sok], ra[sok], r9[sok], r8[sok], r7[sok], r6[sok], r5[sok], r4[sok], r3[sok], r2[sok], r1[sok], r0[sok]); + __m512 _r0 = _mm512_set_ps(r0[sok + N * 15], r0[sok + N * 14], r0[sok + N * 13], r0[sok + N * 12], r0[sok + N * 11], r0[sok + N * 10], r0[sok + N * 9], r0[sok + N * 8], r0[sok + N * 7], r0[sok + N * 6], r0[sok + N * 5], r0[sok + N * 4], r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m512 _w = _mm512_load_ps(kptr); _sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512); @@ -2807,18 +2671,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - const float* r4 = r0 + N * 4; - const float* r5 = r0 + N * 5; - const float* r6 = r0 + N * 6; - const float* r7 = r0 + N * 7; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m256 _r0 = _mm256_set_ps(r7[sok], r6[sok], r5[sok], r4[sok], r3[sok], r2[sok], r1[sok], r0[sok]); + __m256 _r0 = _mm256_set_ps(r0[sok + N * 7], r0[sok + N * 6], r0[sok + N * 5], r0[sok + N * 4], r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m256 _w = _mm256_load_ps(kptr); _sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx); @@ -2847,14 
+2703,10 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& } if (elempack == 1) { - const float* r1 = r0 + N; - const float* r2 = r0 + N * 2; - const float* r3 = r0 + N * 3; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; - __m128 _r0 = _mm_set_ps(r3[sok], r2[sok], r1[sok], r0[sok]); + __m128 _r0 = _mm_set_ps(r0[sok + N * 3], r0[sok + N * 2], r0[sok + N], r0[sok]); __m128 _w = _mm_load_ps(kptr); _sum = _mm_comp_fmadd_ps(_r0, _w, _sum); @@ -2870,14 +2722,12 @@ static void convolution_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& // if (elempack == 1) { - const float* r1 = r0 + N; - for (int k = 0; k < maxk; k++) { const int sok = space_ofs[k]; sum += r0[sok] * kptr[0]; - sum += r1[sok] * kptr[1]; + sum += r0[sok + N] * kptr[1]; kptr += 2; } diff --git a/tests/test_convolution1d.cpp b/tests/test_convolution1d.cpp index bf8f267d7..c8dd55ffe 100644 --- a/tests/test_convolution1d.cpp +++ b/tests/test_convolution1d.cpp @@ -76,38 +76,50 @@ static int test_convolution1d_0() const int d = kdsp[i][1]; const int s = kdsp[i][2]; const int p = kdsp[i][3]; + const int b0 = i % 2; + const int b1 = 1 - b0; /* 0/1 complement of b0, so cases taking b0 and b1 get opposite values of the last argument */ int ret = 0 - || test_convolution1d(9, 1, 1, k, d, s, p, 1) - || test_convolution1d(9, 4, 13, k, d, s, p, 0) - || test_convolution1d(9, 13, 4, k, d, s, p, 1) - || test_convolution1d(9, 12, 12, k, d, s, p, 0) - || test_convolution1d(9, 8, 12, k, d, s, p, 1) - || test_convolution1d(9, 8, 13, k, d, s, p, 0) - || test_convolution1d(9, 13, 8, k, d, s, p, 1) - || test_convolution1d(9, 12, 16, k, d, s, p, 0) - || test_convolution1d(9, 15, 15, k, d, s, p, 0) - || test_convolution1d(9, 16, 16, k, d, s, p, 0) - || test_convolution1d(18, 1, 1, k, d, s, p, 1) - || test_convolution1d(18, 4, 13, k, d, s, p, 0) - || test_convolution1d(18, 13, 4, k, d, s, p, 1) - || test_convolution1d(18, 12, 12, k, d, s, p, 0) - || test_convolution1d(18, 8, 12, k, d, s, p, 1) - || test_convolution1d(18, 8, 13, k, d, s, p, 0) - ||
test_convolution1d(18, 13, 8, k, d, s, p, 1) - || test_convolution1d(18, 12, 16, k, d, s, p, 0) - || test_convolution1d(18, 15, 15, k, d, s, p, 0) - || test_convolution1d(18, 16, 16, k, d, s, p, 0) - || test_convolution1d(25, 1, 1, k, d, s, p, 1) - || test_convolution1d(25, 4, 13, k, d, s, p, 0) - || test_convolution1d(25, 13, 4, k, d, s, p, 1) - || test_convolution1d(25, 12, 12, k, d, s, p, 0) - || test_convolution1d(25, 8, 12, k, d, s, p, 1) - || test_convolution1d(25, 8, 13, k, d, s, p, 0) - || test_convolution1d(25, 13, 8, k, d, s, p, 1) - || test_convolution1d(25, 12, 16, k, d, s, p, 0) - || test_convolution1d(25, 15, 15, k, d, s, p, 0) - || test_convolution1d(25, 16, 16, k, d, s, p, 0); + || test_convolution1d(9, 1, 1, k, d, s, p, b0) + || test_convolution1d(9, 1, 3, k, d, s, p, b1) + || test_convolution1d(9, 1, 7, k, d, s, p, b0) + || test_convolution1d(9, 1, 15, k, d, s, p, b1) + || test_convolution1d(9, 1, 31, k, d, s, p, b0) + || test_convolution1d(9, 3, 1, k, d, s, p, b1) + || test_convolution1d(9, 3, 3, k, d, s, p, b0) + || test_convolution1d(9, 3, 7, k, d, s, p, b1) + || test_convolution1d(9, 3, 15, k, d, s, p, b0) + || test_convolution1d(9, 3, 31, k, d, s, p, b1) + || test_convolution1d(9, 7, 1, k, d, s, p, b0) + || test_convolution1d(9, 7, 3, k, d, s, p, b1) + || test_convolution1d(9, 7, 7, k, d, s, p, b0) + || test_convolution1d(9, 7, 15, k, d, s, p, b1) + || test_convolution1d(9, 7, 31, k, d, s, p, b0) + || test_convolution1d(9, 15, 1, k, d, s, p, b1) + || test_convolution1d(9, 15, 3, k, d, s, p, b0) + || test_convolution1d(9, 15, 7, k, d, s, p, b1) + || test_convolution1d(9, 15, 15, k, d, s, p, b0) + || test_convolution1d(9, 15, 31, k, d, s, p, b1) + || test_convolution1d(9, 31, 1, k, d, s, p, b0) + || test_convolution1d(9, 31, 3, k, d, s, p, b1) + || test_convolution1d(9, 31, 7, k, d, s, p, b0) + || test_convolution1d(9, 31, 15, k, d, s, p, b1) + || test_convolution1d(25, 28, 31, k, d, s, p, b0) + || test_convolution1d(25, 31, 28, k, d, s, p, b1) 
+ || test_convolution1d(25, 28, 28, k, d, s, p, b0) + || test_convolution1d(25, 24, 28, k, d, s, p, b1) + || test_convolution1d(25, 24, 31, k, d, s, p, b0) + || test_convolution1d(25, 28, 24, k, d, s, p, b1) + || test_convolution1d(25, 31, 24, k, d, s, p, b0) + || test_convolution1d(25, 24, 24, k, d, s, p, b1) + || test_convolution1d(25, 28, 48, k, d, s, p, b0) + || test_convolution1d(25, 31, 48, k, d, s, p, b1) + || test_convolution1d(25, 24, 48, k, d, s, p, b0) + || test_convolution1d(25, 48, 28, k, d, s, p, b1) + || test_convolution1d(25, 48, 31, k, d, s, p, b0) + || test_convolution1d(25, 48, 24, k, d, s, p, b1) + || test_convolution1d(25, 31, 31, k, d, s, p, b0) + || test_convolution1d(25, 48, 48, k, d, s, p, b1); if (ret != 0) return -1;