|
- // BUG1989 is pleased to support the open source community by supporting ncnn available.
- //
- // Copyright (C) 2019 BUG1989. All rights reserved.
- //
- // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
- // in compliance with the License. You may obtain a copy of the License at
- //
- // https://opensource.org/licenses/BSD-3-Clause
- //
- // Unless required by applicable law or agreed to in writing, software distributed
- // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
- // CONDITIONS OF ANY KIND, either express or implied. See the License for the
- // specific language governing permissions and limitations under the License.
-
- #if __AVX__
- static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
- {
-
- const float* kernel = _kernel;
-
- // kernel memory packed 8 x 8
- kernel_tm.create(8*kernel_size, inch, outch/8 + (outch%8)/4 + outch%4);
-
- int nn_outch = 0;
- int remain_outch_start = 0;
-
- nn_outch = outch >> 3;
- remain_outch_start = nn_outch << 3;
-
- for (int pp=0; pp<nn_outch; pp++)
- {
- int p = pp * 8;
-
- const float* k0 = kernel + (p+0)*inch*kernel_size;
- const float* k1 = kernel + (p+1)*inch*kernel_size;
- const float* k2 = kernel + (p+2)*inch*kernel_size;
- const float* k3 = kernel + (p+3)*inch*kernel_size;
- const float* k4 = kernel + (p+4)*inch*kernel_size;
- const float* k5 = kernel + (p+5)*inch*kernel_size;
- const float* k6 = kernel + (p+6)*inch*kernel_size;
- const float* k7 = kernel + (p+7)*inch*kernel_size;
-
- float* ktmp = kernel_tm.channel(p/8);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- ktmp[0] = k0[0];
- ktmp[1] = k1[0];
- ktmp[2] = k2[0];
- ktmp[3] = k3[0];
- ktmp[4] = k4[0];
- ktmp[5] = k5[0];
- ktmp[6] = k6[0];
- ktmp[7] = k7[0];
- ktmp += 8;
-
- k0 += 1;
- k1 += 1;
- k2 += 1;
- k3 += 1;
- k4 += 1;
- k5 += 1;
- k6 += 1;
- k7 += 1;
- }
- }
-
- nn_outch = (outch - remain_outch_start) >> 2;
-
- for (int pp=0; pp<nn_outch; pp++)
- {
- int p = remain_outch_start + pp * 4;
-
- const float* k0 = kernel + (p+0)*inch*kernel_size;
- const float* k1 = kernel + (p+1)*inch*kernel_size;
- const float* k2 = kernel + (p+2)*inch*kernel_size;
- const float* k3 = kernel + (p+3)*inch*kernel_size;
-
- float* ktmp = kernel_tm.channel(p/8 + (p%8)/4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- ktmp[0] = k0[0];
- ktmp[1] = k1[0];
- ktmp[2] = k2[0];
- ktmp[3] = k3[0];
- ktmp += 4;
-
- k0 += 1;
- k1 += 1;
- k2 += 1;
- k3 += 1;
- }
- }
-
- remain_outch_start += nn_outch << 2;
-
- for (int p=remain_outch_start; p<outch; p++)
- {
- const float* k0 = kernel + (p+0)*inch*kernel_size;
-
- float* ktmp = kernel_tm.channel(p/8 + (p%8)/4 + p%4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- ktmp[0] = k0[0];
- ktmp++;
- k0++;
- }
- }
- }
-
- static void conv_im2col_sgemm_sse(const Mat &bottom_blob, Mat &top_blob, const Mat & kernel_tm, const Mat& _bias, \
- const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
- {
- int w = bottom_blob.w;
- int inch = bottom_blob.c;
- size_t elemsize = bottom_blob.elemsize;
-
- int outw = top_blob.w;
- int outh = top_blob.h;
- int outch = top_blob.c;
-
- const float* bias = _bias;
-
- // im2col
- Mat bottom_im2col(outw*outh, kernel_h*kernel_w*inch, elemsize, opt.workspace_allocator);
- {
- const int stride = kernel_h*kernel_w*outw*outh;
- float* ret = (float*)bottom_im2col;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int p=0; p<inch; p++)
- {
- const float* input = bottom_blob.channel(p);
- int retID = stride * p;
- for (int u=0; u<kernel_h; u++)
- {
- for (int v=0; v<kernel_w; v++)
- {
- for (int i=0; i<outh; i++)
- {
- for (int j=0; j<outw; j++)
- {
- int row = u + i * stride_h;
- int col = v + j * stride_w;
- int index = row * w + col;
- ret[retID] = input[index];
- retID++;
- }
- }
- }
- }
- }
- }
-
- int kernel_size = kernel_w * kernel_h;
- int out_size = outw * outh;
-
- // bottom_im2col memory packed 8 x 8
- Mat bottom_tm(8*kernel_size, inch, out_size/8 + out_size%8, elemsize, opt.workspace_allocator);
- {
- int nn_size = out_size >> 3;
- int remain_size_start = nn_size << 3;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int ii=0; ii<nn_size; ii++)
- {
- int i = ii * 8;
-
- const float* img0 = bottom_im2col.channel(0);
- img0 += i;
-
- float* tmpptr = bottom_tm.channel(i/8);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- #if __AVX__
- _mm256_storeu_ps(tmpptr, _mm256_loadu_ps(img0));
- #else
- tmpptr[0] = img0[0];
- tmpptr[1] = img0[1];
- tmpptr[2] = img0[2];
- tmpptr[3] = img0[3];
- tmpptr[4] = img0[4];
- tmpptr[5] = img0[5];
- tmpptr[6] = img0[6];
- tmpptr[7] = img0[7];
- #endif // __SSE__
- tmpptr += 8;
- img0 += out_size;
- }
- }
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int i=remain_size_start; i<out_size; i++)
- {
- const float* img0 = bottom_im2col.channel(0);
- img0 += i;
-
- float* tmpptr = bottom_tm.channel(i/8 + i%8);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- tmpptr[0] = img0[0];
-
- tmpptr += 1;
- img0 += out_size;
- }
- }
- }
-
- // sgemm(int M, int N, int L, float* A, float* B, float* C)
- {
- //int M = outch; // outch
- int N = outw * outh; // outsize or out stride
- int L = kernel_w * kernel_h * inch; // ksize * inch
-
- int nn_outch = 0;
- int remain_outch_start = 0;
-
- nn_outch = outch >> 3;
- remain_outch_start = nn_outch << 3;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int pp=0; pp<nn_outch; pp++)
- {
- int i = pp * 8;
-
- float* output0 = top_blob.channel(i);
- float* output1 = top_blob.channel(i+1);
- float* output2 = top_blob.channel(i+2);
- float* output3 = top_blob.channel(i+3);
- float* output4 = top_blob.channel(i+4);
- float* output5 = top_blob.channel(i+5);
- float* output6 = top_blob.channel(i+6);
- float* output7 = top_blob.channel(i+7);
-
- const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
- const float* biasptr = bias ? bias + i : zeros;
-
- int j=0;
- for (; j+7<N; j=j+8)
- {
- const float* vb = bottom_tm.channel(j/8);
- const float* va = kernel_tm.channel(i/8);
- #if __AVX__
- __m256 _sum0 = _mm256_broadcast_ss(biasptr);
- __m256 _sum1 = _mm256_broadcast_ss(biasptr+1);
- __m256 _sum2 = _mm256_broadcast_ss(biasptr+2);
- __m256 _sum3 = _mm256_broadcast_ss(biasptr+3);
- __m256 _sum4 = _mm256_broadcast_ss(biasptr+4);
- __m256 _sum5 = _mm256_broadcast_ss(biasptr+5);
- __m256 _sum6 = _mm256_broadcast_ss(biasptr+6);
- __m256 _sum7 = _mm256_broadcast_ss(biasptr+7);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _va1 = _mm256_broadcast_ss(va+1);
- __m256 _va2 = _mm256_broadcast_ss(va+2);
- __m256 _va3 = _mm256_broadcast_ss(va+3);
- __m256 _vb0 = _mm256_loadu_ps(vb);
- __m256 _vb1 = _mm256_loadu_ps(vb+8);
- __m256 _vb2 = _mm256_loadu_ps(vb+16);
- __m256 _vb3 = _mm256_loadu_ps(vb+24);
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
- _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
- _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
- _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
- _va0 = _mm256_broadcast_ss(va+4);
- _va1 = _mm256_broadcast_ss(va+5);
- _va2 = _mm256_broadcast_ss(va+6);
- _va3 = _mm256_broadcast_ss(va+7);
- _sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
- _sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
- _sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
- _sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70
-
- va += 8;
-
- // k1
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
- _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
- _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
- _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
- _va0 = _mm256_broadcast_ss(va+4);
- _va1 = _mm256_broadcast_ss(va+5);
- _va2 = _mm256_broadcast_ss(va+6);
- _va3 = _mm256_broadcast_ss(va+7);
- _sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
- _sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
- _sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
- _sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71
-
- va += 8;
-
- // k2
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
- _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
- _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
- _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
- _va0 = _mm256_broadcast_ss(va+4);
- _va1 = _mm256_broadcast_ss(va+5);
- _va2 = _mm256_broadcast_ss(va+6);
- _va3 = _mm256_broadcast_ss(va+7);
- _sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
- _sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
- _sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
- _sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72
-
- va += 8;
-
- // k3
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
- _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
- _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
- _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
- _va0 = _mm256_broadcast_ss(va+4);
- _va1 = _mm256_broadcast_ss(va+5);
- _va2 = _mm256_broadcast_ss(va+6);
- _va3 = _mm256_broadcast_ss(va+7);
- _sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
- _sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
- _sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
- _sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73
-
- va += 8;
- vb += 32;
- }
-
- for (; k<L; k++)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _va1 = _mm256_broadcast_ss(va+1);
- __m256 _va2 = _mm256_broadcast_ss(va+2);
- __m256 _va3 = _mm256_broadcast_ss(va+3);
- __m256 _va4 = _mm256_broadcast_ss(va+4);
- __m256 _va5 = _mm256_broadcast_ss(va+5);
- __m256 _va6 = _mm256_broadcast_ss(va+6);
- __m256 _va7 = _mm256_broadcast_ss(va+7);
- __m256 _vb0 = _mm256_loadu_ps(vb);
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
- _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
- _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
- _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
- _sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4); // sum4 = (a00-a07) * k40
- _sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5); // sum5 = (a00-a07) * k50
- _sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6); // sum6 = (a00-a07) * k60
- _sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7); // sum7 = (a00-a07) * k70
-
- va += 8;
- vb += 8;
- }
-
- _mm256_storeu_ps(output0, _sum0);
- _mm256_storeu_ps(output1, _sum1);
- _mm256_storeu_ps(output2, _sum2);
- _mm256_storeu_ps(output3, _sum3);
- _mm256_storeu_ps(output4, _sum4);
- _mm256_storeu_ps(output5, _sum5);
- _mm256_storeu_ps(output6, _sum6);
- _mm256_storeu_ps(output7, _sum7);
- #else
- float sum0[8] = {0};
- float sum1[8] = {0};
- float sum2[8] = {0};
- float sum3[8] = {0};
- float sum4[8] = {0};
- float sum5[8] = {0};
- float sum6[8] = {0};
- float sum7[8] = {0};
-
- int k=0;
- for (; k+7<L; k=k+8)
- {
- for (int n=0; n<8; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- sum4[n] += va[4] * vb[n];
- sum5[n] += va[5] * vb[n];
- sum6[n] += va[6] * vb[n];
- sum7[n] += va[7] * vb[n];
- va += 8;
-
- sum0[n] += va[0] * vb[n+8];
- sum1[n] += va[1] * vb[n+8];
- sum2[n] += va[2] * vb[n+8];
- sum3[n] += va[3] * vb[n+8];
- sum4[n] += va[4] * vb[n+8];
- sum5[n] += va[5] * vb[n+8];
- sum6[n] += va[6] * vb[n+8];
- sum7[n] += va[7] * vb[n+8];
- va += 8;
-
- sum0[n] += va[0] * vb[n+16];
- sum1[n] += va[1] * vb[n+16];
- sum2[n] += va[2] * vb[n+16];
- sum3[n] += va[3] * vb[n+16];
- sum4[n] += va[4] * vb[n+16];
- sum5[n] += va[5] * vb[n+16];
- sum6[n] += va[6] * vb[n+16];
- sum7[n] += va[7] * vb[n+16];
- va += 8;
-
- sum0[n] += va[0] * vb[n+24];
- sum1[n] += va[1] * vb[n+24];
- sum2[n] += va[2] * vb[n+24];
- sum3[n] += va[3] * vb[n+24];
- sum4[n] += va[4] * vb[n+24];
- sum5[n] += va[5] * vb[n+24];
- sum6[n] += va[6] * vb[n+24];
- sum7[n] += va[7] * vb[n+24];
- va += 8;
-
- sum0[n] += va[0] * vb[n+32];
- sum1[n] += va[1] * vb[n+32];
- sum2[n] += va[2] * vb[n+32];
- sum3[n] += va[3] * vb[n+32];
- sum4[n] += va[4] * vb[n+32];
- sum5[n] += va[5] * vb[n+32];
- sum6[n] += va[6] * vb[n+32];
- sum7[n] += va[7] * vb[n+32];
- va += 8;
-
- sum0[n] += va[0] * vb[n+40];
- sum1[n] += va[1] * vb[n+40];
- sum2[n] += va[2] * vb[n+40];
- sum3[n] += va[3] * vb[n+40];
- sum4[n] += va[4] * vb[n+40];
- sum5[n] += va[5] * vb[n+40];
- sum6[n] += va[6] * vb[n+40];
- sum7[n] += va[7] * vb[n+40];
- va += 8;
-
- sum0[n] += va[0] * vb[n+48];
- sum1[n] += va[1] * vb[n+48];
- sum2[n] += va[2] * vb[n+48];
- sum3[n] += va[3] * vb[n+48];
- sum4[n] += va[4] * vb[n+48];
- sum5[n] += va[5] * vb[n+48];
- sum6[n] += va[6] * vb[n+48];
- sum7[n] += va[7] * vb[n+48];
- va += 8;
-
- sum0[n] += va[0] * vb[n+56];
- sum1[n] += va[1] * vb[n+56];
- sum2[n] += va[2] * vb[n+56];
- sum3[n] += va[3] * vb[n+56];
- sum4[n] += va[4] * vb[n+56];
- sum5[n] += va[5] * vb[n+56];
- sum6[n] += va[6] * vb[n+56];
- sum7[n] += va[7] * vb[n+56];
- va -= 56;
- }
-
- va += 64;
- vb += 64;
- }
-
- for (; k<L; k++)
- {
- for (int n=0; n<8; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- sum4[n] += va[4] * vb[n];
- sum5[n] += va[5] * vb[n];
- sum6[n] += va[6] * vb[n];
- sum7[n] += va[7] * vb[n];
- }
-
- va += 8;
- vb += 8;
- }
-
- for (int n=0; n<8; n++)
- {
- output0[n] = sum0[n] + biasptr[0];
- output1[n] = sum1[n] + biasptr[1];
- output2[n] = sum2[n] + biasptr[2];
- output3[n] = sum3[n] + biasptr[3];
- output4[n] = sum4[n] + biasptr[4];
- output5[n] = sum5[n] + biasptr[5];
- output6[n] = sum6[n] + biasptr[6];
- output7[n] = sum7[n] + biasptr[7];
- }
- #endif // __AVX__
- output0 += 8;
- output1 += 8;
- output2 += 8;
- output3 += 8;
- output4 += 8;
- output5 += 8;
- output6 += 8;
- output7 += 8;
- }
-
- for (; j<N; j++)
- {
- const float* vb = bottom_tm.channel(j/8 + j%8);
- const float* va = kernel_tm.channel(i/8);
-
- #if __AVX__
- __m256 _sum0_7 = _mm256_loadu_ps(biasptr);
- __m256 _sum0 = _mm256_set1_ps(0.0);
- __m256 _sum1 = _mm256_set1_ps(0.0);
- __m256 _sum2 = _mm256_set1_ps(0.0);
- __m256 _sum3 = _mm256_set1_ps(0.0);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- __m256 _vb0 = _mm256_broadcast_ss(vb);
- __m256 _vb1 = _mm256_broadcast_ss(vb+1);
- __m256 _vb2 = _mm256_broadcast_ss(vb+2);
- __m256 _vb3 = _mm256_broadcast_ss(vb+3);
- __m256 _va0 = _mm256_loadu_ps(va);
- __m256 _va1 = _mm256_loadu_ps(va+8);
- __m256 _va2 = _mm256_loadu_ps(va+16);
- __m256 _va3 = _mm256_loadu_ps(va+24);
-
- _sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0);// sum0 += (k00-k70) * a00
- _sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1);// sum1 += (k01-k71) * a10
- _sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2);// sum2 += (k02-k72) * a20
- _sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3);// sum3 += (k03-k73) * a30
-
- va += 32;
- vb += 4;
- }
-
- _sum0 = _mm256_add_ps(_sum0, _sum1);
- _sum2 = _mm256_add_ps(_sum2, _sum3);
- _sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
- _sum0_7 = _mm256_add_ps(_sum0_7, _sum2);
-
- for (; k<L; k++)
- {
- __m256 _vb0 = _mm256_broadcast_ss(vb);
- __m256 _va = _mm256_loadu_ps(va);
-
- _sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7);// sum0 += (k00-k70) * a00
-
- va += 8;
- vb += 1;
- }
-
- output0[0] = _sum0_7[0];
- output1[0] = _sum0_7[1];
- output2[0] = _sum0_7[2];
- output3[0] = _sum0_7[3];
- output4[0] = _sum0_7[4];
- output5[0] = _sum0_7[5];
- output6[0] = _sum0_7[6];
- output7[0] = _sum0_7[7];
- #else
- float sum0 = biasptr[0];
- float sum1 = biasptr[1];
- float sum2 = biasptr[2];
- float sum3 = biasptr[3];
- float sum4 = biasptr[4];
- float sum5 = biasptr[5];
- float sum6 = biasptr[6];
- float sum7 = biasptr[7];
-
- for (int k=0; k<L; k++)
- {
- sum0 += va[0] * vb[0];
- sum1 += va[1] * vb[0];
- sum2 += va[2] * vb[0];
- sum3 += va[3] * vb[0];
- sum4 += va[4] * vb[0];
- sum5 += va[5] * vb[0];
- sum6 += va[6] * vb[0];
- sum7 += va[7] * vb[0];
-
- va += 8;
- vb += 1;
- }
-
- output0[0] = sum0;
- output1[0] = sum1;
- output2[0] = sum2;
- output3[0] = sum3;
- output4[0] = sum4;
- output5[0] = sum5;
- output6[0] = sum6;
- output7[0] = sum7;
- #endif // __AVX__
- output0++;
- output1++;
- output2++;
- output3++;
- output4++;
- output5++;
- output6++;
- output7++;
- }
- }
-
- nn_outch = (outch - remain_outch_start) >> 2;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int pp=0; pp<nn_outch; pp++)
- {
- int i = remain_outch_start + pp * 4;
-
- float* output0 = top_blob.channel(i);
- float* output1 = top_blob.channel(i+1);
- float* output2 = top_blob.channel(i+2);
- float* output3 = top_blob.channel(i+3);
-
- const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
- const float* biasptr = bias ? bias + i : zeros;
-
- int j=0;
- for (; j+7<N; j=j+8)
- {
- const float* vb = bottom_tm.channel(j/8);
- const float* va = kernel_tm.channel(i/8 + (i%8)/4);
- #if __AVX__
- __m256 _sum0 = _mm256_broadcast_ss(biasptr);
- __m256 _sum1 = _mm256_broadcast_ss(biasptr+1);
- __m256 _sum2 = _mm256_broadcast_ss(biasptr+2);
- __m256 _sum3 = _mm256_broadcast_ss(biasptr+3);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _va1 = _mm256_broadcast_ss(va+1);
- __m256 _va2 = _mm256_broadcast_ss(va+2);
- __m256 _va3 = _mm256_broadcast_ss(va+3);
- __m256 _vb0 = _mm256_loadu_ps(vb);
- __m256 _vb1 = _mm256_loadu_ps(vb+8);
- __m256 _vb2 = _mm256_loadu_ps(vb+16);
- __m256 _vb3 = _mm256_loadu_ps(vb+24);
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
- _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
- _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
- _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
-
- va += 4;
-
- // k1
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
- _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
- _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
- _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
-
- va += 4;
-
- // k2
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
- _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
- _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
- _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
-
- va += 4;
-
- // k3
- _va0 = _mm256_broadcast_ss(va);
- _va1 = _mm256_broadcast_ss(va+1);
- _va2 = _mm256_broadcast_ss(va+2);
- _va3 = _mm256_broadcast_ss(va+3);
- _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
- _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
- _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
- _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
-
- va += 4;
- vb += 32;
- }
-
- for (; k<L; k++)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _va1 = _mm256_broadcast_ss(va+1);
- __m256 _va2 = _mm256_broadcast_ss(va+2);
- __m256 _va3 = _mm256_broadcast_ss(va+3);
- __m256 _vb0 = _mm256_loadu_ps(vb);
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
- _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
- _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
- _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
-
- va += 4;
- vb += 4;
- }
-
- _mm256_storeu_ps(output0, _sum0);
- _mm256_storeu_ps(output1, _sum1);
- _mm256_storeu_ps(output2, _sum2);
- _mm256_storeu_ps(output3, _sum3);
- #else
- float sum0[8] = {0};
- float sum1[8] = {0};
- float sum2[8] = {0};
- float sum3[8] = {0};
-
- int k=0;
- for (; k+7<L; k=k+8)
- {
- for (int n=0; n<8; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- va += 4;
-
- sum0[n] += va[0] * vb[n+8];
- sum1[n] += va[1] * vb[n+8];
- sum2[n] += va[2] * vb[n+8];
- sum3[n] += va[3] * vb[n+8];
- va += 4;
-
- sum0[n] += va[0] * vb[n+16];
- sum1[n] += va[1] * vb[n+16];
- sum2[n] += va[2] * vb[n+16];
- sum3[n] += va[3] * vb[n+16];
- va += 4;
-
- sum0[n] += va[0] * vb[n+24];
- sum1[n] += va[1] * vb[n+24];
- sum2[n] += va[2] * vb[n+24];
- sum3[n] += va[3] * vb[n+24];
- va += 4;
-
- sum0[n] += va[0] * vb[n+32];
- sum1[n] += va[1] * vb[n+32];
- sum2[n] += va[2] * vb[n+32];
- sum3[n] += va[3] * vb[n+32];
- va += 4;
-
- sum0[n] += va[0] * vb[n+40];
- sum1[n] += va[1] * vb[n+40];
- sum2[n] += va[2] * vb[n+40];
- sum3[n] += va[3] * vb[n+40];
- va += 4;
-
- sum0[n] += va[0] * vb[n+48];
- sum1[n] += va[1] * vb[n+48];
- sum2[n] += va[2] * vb[n+48];
- sum3[n] += va[3] * vb[n+48];
- va += 4;
-
- sum0[n] += va[0] * vb[n+56];
- sum1[n] += va[1] * vb[n+56];
- sum2[n] += va[2] * vb[n+56];
- sum3[n] += va[3] * vb[n+56];
- va -= 28;
- }
-
- va += 32;
- vb += 64;
- }
-
- for (; k<L; k++)
- {
- for (int n=0; n<8; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- }
-
- va += 4;
- vb += 8;
- }
-
- for (int n=0; n<8; n++)
- {
- output0[n] = sum0[n] + biasptr[0];
- output1[n] = sum1[n] + biasptr[1];
- output2[n] = sum2[n] + biasptr[2];
- output3[n] = sum3[n] + biasptr[3];
- }
- #endif // __AVX__
- output0 += 8;
- output1 += 8;
- output2 += 8;
- output3 += 8;
- }
-
- for (; j<N; j++)
- {
- const float* vb = bottom_tm.channel(j/8 + j%8);
- const float* va = kernel_tm.channel(i/8 + (i%8)/4);
- #if __AVX__
- __m128 _sum0_3 = _mm_loadu_ps(biasptr);
- __m128 _sum0 = _mm_set1_ps(0.0);
- __m128 _sum1 = _mm_set1_ps(0.0);
- __m128 _sum2 = _mm_set1_ps(0.0);
- __m128 _sum3 = _mm_set1_ps(0.0);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- __m128 _vb0 = _mm_set1_ps(vb[0]);
- __m128 _vb1 = _mm_set1_ps(vb[1]);
- __m128 _vb2 = _mm_set1_ps(vb[2]);
- __m128 _vb3 = _mm_set1_ps(vb[3]);
- __m128 _va0 = _mm_loadu_ps(va);
- __m128 _va1 = _mm_loadu_ps(va+4);
- __m128 _va2 = _mm_loadu_ps(va+8);
- __m128 _va3 = _mm_loadu_ps(va+12);
-
- _sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0);// sum0 += (k00-k30) * a00
- _sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1);// sum1 += (k01-k31) * a10
- _sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2);// sum2 += (k02-k32) * a20
- _sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3);// sum3 += (k03-k33) * a30
-
- va += 16;
- vb += 4;
- }
-
- _sum0 = _mm_add_ps(_sum0, _sum1);
- _sum2 = _mm_add_ps(_sum2, _sum3);
- _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
- _sum0_3 = _mm_add_ps(_sum0_3, _sum2);
-
- for (; k<L; k++)
- {
- __m128 _vb0 = _mm_set1_ps(vb[0]);
- __m128 _va = _mm_loadu_ps(va);
-
- _sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3);// sum0 += (k00-k30) * a00
-
- va += 4;
- vb += 1;
- }
-
- output0[0] = _sum0_3[0];
- output1[0] = _sum0_3[1];
- output2[0] = _sum0_3[2];
- output3[0] = _sum0_3[3];
- #else
- float sum0 = biasptr[0];
- float sum1 = biasptr[1];
- float sum2 = biasptr[2];
- float sum3 = biasptr[3];
-
- for (int k=0; k<L; k++)
- {
- sum0 += va[0] * vb[0];
- sum1 += va[1] * vb[0];
- sum2 += va[2] * vb[0];
- sum3 += va[3] * vb[0];
-
- va += 4;
- vb += 1;
- }
-
- output0[0] = sum0;
- output1[0] = sum1;
- output2[0] = sum2;
- output3[0] = sum3;
- #endif // __AVX__
- output0++;
- output1++;
- output2++;
- output3++;
- }
- }
-
- remain_outch_start += nn_outch << 2;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int i=remain_outch_start; i<outch; i++)
- {
- float* output = top_blob.channel(i);
-
- const float bias0 = bias ? bias[i] : 0.f;
-
- int j=0;
- for (; j+7<N; j=j+8)
- {
- const float* vb = bottom_tm.channel(j/8);
- const float* va = kernel_tm.channel(i/8 + (i%8)/4 + i%4);
- #if __AVX__
- __m256 _sum0 = _mm256_broadcast_ss(&bias0);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _va1 = _mm256_broadcast_ss(va+1);
- __m256 _va2 = _mm256_broadcast_ss(va+2);
- __m256 _va3 = _mm256_broadcast_ss(va+3);
- __m256 _vb0 = _mm256_loadu_ps(vb);
- __m256 _vb1 = _mm256_loadu_ps(vb+8);
- __m256 _vb2 = _mm256_loadu_ps(vb+16);
- __m256 _vb3 = _mm256_loadu_ps(vb+24);
-
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
- _sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
- _sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
- _sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03
-
- va += 4;
- vb += 32;
- }
-
- for (; k<L; k++)
- {
- // k0
- __m256 _va0 = _mm256_broadcast_ss(va);
- __m256 _vb0 = _mm256_loadu_ps(vb);
-
- _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
-
- va += 1;
- vb += 4;
- }
-
- _mm256_storeu_ps(output, _sum0);
- #else
- float sum[8] = {0};
-
- int k=0;
- for (; k+7<L; k=k+8)
- {
- for (int n=0; n<8; n++)
- {
- sum[n] += va[0] * vb[n];
- sum[n] += va[1] * vb[n+8];
- sum[n] += va[2] * vb[n+16];
- sum[n] += va[3] * vb[n+24];
- sum[n] += va[4] * vb[n+32];
- sum[n] += va[5] * vb[n+40];
- sum[n] += va[6] * vb[n+48];
- sum[n] += va[7] * vb[n+56];
- }
-
- va += 8;
- vb += 64;
- }
-
- for (; k<L; k++)
- {
- for (int n=0; n<8; n++)
- {
- sum[n] += va[0] * vb[n];
- }
-
- va += 1;
- vb += 8;
- }
-
- for (int n=0; n<8; n++)
- {
- output[n] = sum[n] + bias0;
- }
- #endif // __AVX__
- output += 8;
- }
-
- for (; j<N; j++)
- {
- const float* vb = bottom_tm.channel(j/8 + j%8);
- const float* va = kernel_tm.channel(i/8 + (i%8)/4 + i%4);
-
- int k=0;
- #if __AVX__
- __m128 _sum0 = _mm_set1_ps(0.f);
-
- for (; k+3<L; k+=4)
- {
- __m128 _p0 = _mm_loadu_ps(vb);
- vb += 4;
-
- __m128 _k0 = _mm_loadu_ps(va);
- va += 4;
-
- _sum0 = _mm_fmadd_ps(_p0, _k0, _sum0);
- }
- float sum0 = bias0 + _sum0[0] + _sum0[1] + _sum0[2] + _sum0[3];
- #else
- float sum0 = bias0;
- #endif // __AVX__
- for (; k<L; k++)
- {
- sum0 += va[0] * vb[0];
-
- va += 1;
- vb += 1;
- }
- output[0] = sum0;
-
- output++;
- }
- }
- }
- }
- #else
- static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
- {
- const float* kernel = _kernel;
-
- // kernel memory packed 4 x 4
- kernel_tm.create(4*kernel_size, inch, outch/4 + outch%4);
-
- int nn_outch = 0;
- int remain_outch_start = 0;
-
- nn_outch = outch >> 2;
- remain_outch_start = nn_outch << 2;
-
- for (int pp=0; pp<nn_outch; pp++)
- {
- int p = pp * 4;
-
- const float* k0 = kernel + (p+0)*inch*kernel_size;
- const float* k1 = kernel + (p+1)*inch*kernel_size;
- const float* k2 = kernel + (p+2)*inch*kernel_size;
- const float* k3 = kernel + (p+3)*inch*kernel_size;
-
- float* ktmp = kernel_tm.channel(p/4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- ktmp[0] = k0[0];
- ktmp[1] = k1[0];
- ktmp[2] = k2[0];
- ktmp[3] = k3[0];
- ktmp += 4;
-
- k0 += 1;
- k1 += 1;
- k2 += 1;
- k3 += 1;
- }
- }
-
- for (int p=remain_outch_start; p<outch; p++)
- {
- const float* k0 = kernel + (p+0)*inch*kernel_size;
-
- float* ktmp = kernel_tm.channel(p/4 + p%4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- ktmp[0] = k0[0];
- ktmp++;
- k0++;
- }
- }
- }
-
- static void conv_im2col_sgemm_sse(const Mat &bottom_blob, Mat &top_blob, const Mat & kernel_tm, const Mat& _bias, \
- const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
- {
- int w = bottom_blob.w;
- int inch = bottom_blob.c;
- size_t elemsize = bottom_blob.elemsize;
-
- int outw = top_blob.w;
- int outh = top_blob.h;
- int outch = top_blob.c;
-
- const float* bias = _bias;
-
- // im2col
- Mat bottom_im2col(outw*outh, kernel_h*kernel_w*inch, elemsize, opt.workspace_allocator);
- {
- const int stride = kernel_h*kernel_w*outw*outh;
- float* ret = (float*)bottom_im2col;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int p=0; p<inch; p++)
- {
- const float* input = bottom_blob.channel(p);
- int retID = stride * p;
- for (int u=0; u<kernel_h; u++)
- {
- for (int v=0; v<kernel_w; v++)
- {
- for (int i=0; i<outh; i++)
- {
- for (int j=0; j<outw; j++)
- {
- int row = u + i * stride_h;
- int col = v + j * stride_w;
- int index = row * w + col;
- ret[retID] = input[index];
- retID++;
- }
- }
- }
- }
- }
- }
-
- int kernel_size = kernel_w * kernel_h;
- int out_size = outw * outh;
-
- // bottom_im2col memory packed 4 x 4
- Mat bottom_tm(4*kernel_size, inch, out_size/4 + out_size%4, elemsize, opt.workspace_allocator);
- {
- int nn_size = out_size >> 2;
- int remain_size_start = nn_size << 2;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int ii=0; ii<nn_size; ii++)
- {
- int i = ii * 4;
-
- const float* img0 = bottom_im2col.channel(0);
- img0 += i;
-
- float* tmpptr = bottom_tm.channel(i/4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- #if __SSE__
- _mm_storeu_ps(tmpptr, _mm_loadu_ps(img0));
- #else
- tmpptr[0] = img0[0];
- tmpptr[1] = img0[1];
- tmpptr[2] = img0[2];
- tmpptr[3] = img0[3];
- #endif // __SSE__
- tmpptr += 4;
- img0 += out_size;
- }
- }
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int i=remain_size_start; i<out_size; i++)
- {
- const float* img0 = bottom_im2col.channel(0);
- img0 += i;
-
- float* tmpptr = bottom_tm.channel(i/4 + i%4);
-
- for (int q=0; q<inch*kernel_size; q++)
- {
- tmpptr[0] = img0[0];
-
- tmpptr += 1;
- img0 += out_size;
- }
- }
- }
-
- // sgemm(int M, int N, int L, float* A, float* B, float* C)
- {
- //int M = outch; // outch
- int N = outw * outh; // outsize or out stride
- int L = kernel_w * kernel_h * inch; // ksize * inch
-
- int nn_outch = 0;
- int remain_outch_start = 0;
-
- nn_outch = outch >> 2;
- remain_outch_start = nn_outch << 2;
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int pp=0; pp<nn_outch; pp++)
- {
- int i = pp * 4;
-
- float* output0 = top_blob.channel(i);
- float* output1 = top_blob.channel(i+1);
- float* output2 = top_blob.channel(i+2);
- float* output3 = top_blob.channel(i+3);
-
- const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
- const float* biasptr = bias ? bias + i : zeros;
-
- int j=0;
- for (; j+3<N; j=j+4)
- {
- const float* vb = bottom_tm.channel(j/4);
- const float* va = kernel_tm.channel(i/4);
- #if __SSE__
- __m128 _sum0 = _mm_set1_ps(biasptr[0]);
- __m128 _sum1 = _mm_set1_ps(biasptr[1]);
- __m128 _sum2 = _mm_set1_ps(biasptr[2]);
- __m128 _sum3 = _mm_set1_ps(biasptr[3]);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- // k0
- __m128 _vb = _mm_loadu_ps(vb);
- __m128 _va0 = _mm_set1_ps(va[0]);
- __m128 _va1 = _mm_set1_ps(va[1]);
- __m128 _va2 = _mm_set1_ps(va[2]);
- __m128 _va3 = _mm_set1_ps(va[3]);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a00-a03) * k00
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a00-a03) * k10
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a00-a03) * k20
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a00-a03) * k30
-
- // k1
- _vb = _mm_loadu_ps(vb+4);
- _va0 = _mm_set1_ps(va[4]);
- _va1 = _mm_set1_ps(va[5]);
- _va2 = _mm_set1_ps(va[6]);
- _va3 = _mm_set1_ps(va[7]);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a10-a13) * k01
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a10-a13) * k11
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a10-a13) * k21
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a10-a13) * k31
-
- // k2
- _vb = _mm_loadu_ps(vb+8);
- _va0 = _mm_set1_ps(va[8]);
- _va1 = _mm_set1_ps(va[9]);
- _va2 = _mm_set1_ps(va[10]);
- _va3 = _mm_set1_ps(va[11]);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a20-a23) * k02
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a20-a23) * k12
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a20-a23) * k22
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a20-a23) * k32
-
- // k3
- _vb = _mm_loadu_ps(vb+12);
- _va0 = _mm_set1_ps(va[12]);
- _va1 = _mm_set1_ps(va[13]);
- _va2 = _mm_set1_ps(va[14]);
- _va3 = _mm_set1_ps(va[15]);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a30-a33) * k03
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a30-a33) * k13
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a30-a33) * k23
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a30-a33) * k33
-
- va += 16;
- vb += 16;
- }
-
- for (; k<L; k++)
- {
- // k0
- __m128 _vb = _mm_loadu_ps(vb);
- __m128 _va0 = _mm_set1_ps(va[0]);
- __m128 _va1 = _mm_set1_ps(va[1]);
- __m128 _va2 = _mm_set1_ps(va[2]);
- __m128 _va3 = _mm_set1_ps(va[3]);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a00-a03) * k00
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a00-a03) * k10
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a00-a03) * k20
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a00-a03) * k30
-
- va += 4;
- vb += 4;
- }
- _mm_storeu_ps(output0, _sum0);
- _mm_storeu_ps(output1, _sum1);
- _mm_storeu_ps(output2, _sum2);
- _mm_storeu_ps(output3, _sum3);
- #else
- float sum0[4] = {0};
- float sum1[4] = {0};
- float sum2[4] = {0};
- float sum3[4] = {0};
-
- int k=0;
- for (; k+7<L; k=k+8)
- {
- for (int n=0; n<4; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- va += 4;
-
- sum0[n] += va[0] * vb[n+4];
- sum1[n] += va[1] * vb[n+4];
- sum2[n] += va[2] * vb[n+4];
- sum3[n] += va[3] * vb[n+4];
- va += 4;
-
- sum0[n] += va[0] * vb[n+8];
- sum1[n] += va[1] * vb[n+8];
- sum2[n] += va[2] * vb[n+8];
- sum3[n] += va[3] * vb[n+8];
- va += 4;
-
- sum0[n] += va[0] * vb[n+12];
- sum1[n] += va[1] * vb[n+12];
- sum2[n] += va[2] * vb[n+12];
- sum3[n] += va[3] * vb[n+12];
- va += 4;
-
- sum0[n] += va[0] * vb[n+16];
- sum1[n] += va[1] * vb[n+16];
- sum2[n] += va[2] * vb[n+16];
- sum3[n] += va[3] * vb[n+16];
- va += 4;
-
- sum0[n] += va[0] * vb[n+20];
- sum1[n] += va[1] * vb[n+20];
- sum2[n] += va[2] * vb[n+20];
- sum3[n] += va[3] * vb[n+20];
- va += 4;
-
- sum0[n] += va[0] * vb[n+24];
- sum1[n] += va[1] * vb[n+24];
- sum2[n] += va[2] * vb[n+24];
- sum3[n] += va[3] * vb[n+24];
- va += 4;
-
- sum0[n] += va[0] * vb[n+28];
- sum1[n] += va[1] * vb[n+28];
- sum2[n] += va[2] * vb[n+28];
- sum3[n] += va[3] * vb[n+28];
- va -= 28;
- }
-
- va += 32;
- vb += 32;
- }
-
- for (; k<L; k++)
- {
- for (int n=0; n<4; n++)
- {
- sum0[n] += va[0] * vb[n];
- sum1[n] += va[1] * vb[n];
- sum2[n] += va[2] * vb[n];
- sum3[n] += va[3] * vb[n];
- }
-
- va += 4;
- vb += 4;
- }
-
- for (int n=0; n<4; n++)
- {
- output0[n] = sum0[n] + biasptr[0];
- output1[n] = sum1[n] + biasptr[1];
- output2[n] = sum2[n] + biasptr[2];
- output3[n] = sum3[n] + biasptr[3];
- }
- #endif // __SSE__
- output0 += 4;
- output1 += 4;
- output2 += 4;
- output3 += 4;
- }
-
- for (; j<N; j++)
- {
- const float* vb = bottom_tm.channel(j/4 + j%4);
- const float* va = kernel_tm.channel(i/4);
- #if __SSE__
- __m128 _sum0_3 = _mm_loadu_ps(biasptr);
- __m128 _sum0 = _mm_set1_ps(0.0);
- __m128 _sum1 = _mm_set1_ps(0.0);
- __m128 _sum2 = _mm_set1_ps(0.0);
- __m128 _sum3 = _mm_set1_ps(0.0);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- __m128 _vb0 = _mm_set1_ps(vb[0]);
- __m128 _vb1 = _mm_set1_ps(vb[1]);
- __m128 _vb2 = _mm_set1_ps(vb[2]);
- __m128 _vb3 = _mm_set1_ps(vb[3]);
- __m128 _va0 = _mm_loadu_ps(va);
- __m128 _va1 = _mm_loadu_ps(va+4);
- __m128 _va2 = _mm_loadu_ps(va+8);
- __m128 _va3 = _mm_loadu_ps(va+12);
-
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_va0, _vb0));// sum0 += (k00-k30) * a00
- _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_va1, _vb1));// sum1 += (k01-k31) * a10
- _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_va2, _vb2));// sum2 += (k02-k32) * a20
- _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_va3, _vb3));// sum3 += (k03-k33) * a30
-
- va += 16;
- vb += 4;
- }
-
- _sum0 = _mm_add_ps(_sum0, _sum1);
- _sum2 = _mm_add_ps(_sum2, _sum3);
- _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
- _sum0_3 = _mm_add_ps(_sum0_3, _sum2);
-
- for (; k<L; k++)
- {
- __m128 _vb0 = _mm_set1_ps(vb[0]);
- __m128 _va = _mm_loadu_ps(va);
-
- _sum0_3 = _mm_add_ps(_sum0_3, _mm_mul_ps(_va, _vb0));// sum0 += (k00-k30) * a00
-
- va += 4;
- vb += 1;
- }
- output0[0] = _sum0_3[0];
- output1[0] = _sum0_3[1];
- output2[0] = _sum0_3[2];
- output3[0] = _sum0_3[3];
- #else
- float sum0 = biasptr[0];
- float sum1 = biasptr[1];
- float sum2 = biasptr[2];
- float sum3 = biasptr[3];
-
- for (int k=0; k<L; k++)
- {
- sum0 += va[0] * vb[0];
- sum1 += va[1] * vb[0];
- sum2 += va[2] * vb[0];
- sum3 += va[3] * vb[0];
-
- va += 4;
- vb += 1;
- }
-
- output0[0] = sum0;
- output1[0] = sum1;
- output2[0] = sum2;
- output3[0] = sum3;
- #endif // __SSE__
- output0++;
- output1++;
- output2++;
- output3++;
- }
- }
-
- #pragma omp parallel for num_threads(opt.num_threads)
- for (int i=remain_outch_start; i<outch; i++)
- {
- float* output = top_blob.channel(i);
-
- const float bias0 = bias ? bias[i] : 0.f;
-
- int j=0;
- for (; j+3<N; j=j+4)
- {
- const float* vb = bottom_tm.channel(j/4);
- const float* va = kernel_tm.channel(i/4 + i%4);
- #if __SSE__
- __m128 _sum0 = _mm_set1_ps(bias0);
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- // k0
- __m128 _va0 = _mm_set1_ps(va[0]);
- __m128 _va1 = _mm_set1_ps(va[1]);
- __m128 _va2 = _mm_set1_ps(va[2]);
- __m128 _va3 = _mm_set1_ps(va[3]);
- __m128 _vb0 = _mm_loadu_ps(vb);
- __m128 _vb1 = _mm_loadu_ps(vb+4);
- __m128 _vb2 = _mm_loadu_ps(vb+8);
- __m128 _vb3 = _mm_loadu_ps(vb+12);
-
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0));// sum0 = (a00-a03) * k00
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb1, _va1));// sum0 += (a10-a13) * k01
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb2, _va2));// sum0 += (a20-a23) * k02
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb3, _va3));// sum0 += (a30-a33) * k03
-
- va += 4;
- vb += 16;
- }
-
- for (; k<L; k++)
- {
- // k0
- __m128 _va0 = _mm_set1_ps(va[0]);
- __m128 _vb0 = _mm_loadu_ps(vb);
-
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00
-
- va += 1;
- vb += 4;
- }
- _mm_storeu_ps(output, _sum0);
- #else
- float sum[4] = {0};
-
- int k=0;
- for (; k+3<L; k=k+4)
- {
- for (int n=0; n<4; n++)
- {
- sum[n] += va[0] * vb[n];
- sum[n] += va[1] * vb[n+4];
- sum[n] += va[2] * vb[n+8];
- sum[n] += va[3] * vb[n+12];
- sum[n] += va[4] * vb[n+16];
- sum[n] += va[5] * vb[n+20];
- sum[n] += va[6] * vb[n+24];
- sum[n] += va[7] * vb[n+28];
- }
-
- va += 8;
- vb += 32;
- }
-
- for (; k<L; k++)
- {
- for (int n=0; n<4; n++)
- {
- sum[n] += va[0] * vb[n];
- }
-
- va += 1;
- vb += 4;
- }
-
- for (int n=0; n<4; n++)
- {
- output[n] = sum[n] + bias0;
- }
- #endif // __SSE__
- output += 4;
- }
-
- for (; j<N; j++)
- {
- const float* vb = bottom_tm.channel(j/4 + j%4);
- const float* va = kernel_tm.channel(i/4 + i%4);
-
- int k=0;
- #if __SSE__
- __m128 _sum0 = _mm_set1_ps(0.f);
-
- for (; k+3<L; k+=4)
- {
- __m128 _p0 = _mm_loadu_ps(vb);
- __m128 _k0 = _mm_loadu_ps(va);
- _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));
-
- va += 4;
- vb += 4;
- }
- float sum0 = bias0 + _sum0[0] + _sum0[1] + _sum0[2] + _sum0[3];
- #else
- float sum0 = bias0;
- #endif // __SSE__
- for (; k<L; k++)
- {
- sum0 += va[0] * vb[0];
-
- va += 1;
- vb += 1;
- }
- output[0] = sum0;
-
- output++;
- }
- }
- }
- }
- #endif
|