Browse Source

optimize conv sgemm with sse on intel platform (#1035)

* optimize conv sgemm with sse

* Update convolution_x86.cpp
tags/20190611
BUG1989 nihui 7 years ago
parent
commit
c2022f4501
3 changed files with 652 additions and 14 deletions
  1. +90
    -13
      src/layer/x86/convolution_3x3.h
  2. +558
    -0
      src/layer/x86/convolution_sgemm.h
  3. +4
    -1
      src/layer/x86/convolution_x86.cpp

+ 90
- 13
src/layer/x86/convolution_3x3.h View File

@@ -1224,6 +1224,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
{
const float* kptr = kernel_tm_test[r].channel(p/8);
const float* r0 = bottom_blob_tm.channel(tiles*r+i);
#if __AVX__ || __SSE__
#if __AVX__
float zero_val = 0.f;
__m128 _sum0 = _mm_broadcast_ss(&zero_val);
@@ -1234,9 +1235,17 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
__m128 _sum5 = _mm_broadcast_ss(&zero_val);
__m128 _sum6 = _mm_broadcast_ss(&zero_val);
__m128 _sum7 = _mm_broadcast_ss(&zero_val);

#else
__m128 _sum0 = _mm_set1_ps(0.f);
__m128 _sum1 = _mm_set1_ps(0.f);
__m128 _sum2 = _mm_set1_ps(0.f);
__m128 _sum3 = _mm_set1_ps(0.f);
__m128 _sum4 = _mm_set1_ps(0.f);
__m128 _sum5 = _mm_set1_ps(0.f);
__m128 _sum6 = _mm_set1_ps(0.f);
__m128 _sum7 = _mm_set1_ps(0.f);
#endif
int q=0;

for (; q+3<inch; q=q+4)
{
__m128 _r0 = _mm_loadu_ps(r0);
@@ -1252,6 +1261,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
__m128 _k5 = _mm_loadu_ps(kptr+20);
__m128 _k6 = _mm_loadu_ps(kptr+24);
__m128 _k7 = _mm_loadu_ps(kptr+28);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
@@ -1260,7 +1270,16 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
_sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
_sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);

#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
_sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
_sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
_sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
_sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
#endif
kptr += 32;
_k0 = _mm_loadu_ps(kptr);
_k1 = _mm_loadu_ps(kptr+4);
@@ -1270,6 +1289,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_k5 = _mm_loadu_ps(kptr+20);
_k6 = _mm_loadu_ps(kptr+24);
_k7 = _mm_loadu_ps(kptr+28);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r1, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r1, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r1, _k2, _sum2);
@@ -1278,6 +1298,16 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_sum5 = _mm_fmadd_ps(_r1, _k5, _sum5);
_sum6 = _mm_fmadd_ps(_r1, _k6, _sum6);
_sum7 = _mm_fmadd_ps(_r1, _k7, _sum7);
#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r1, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r1, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r1, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r1, _k3));
_sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r1, _k4));
_sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r1, _k5));
_sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r1, _k6));
_sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r1, _k7));
#endif

kptr += 32;
_k0 = _mm_loadu_ps(kptr);
@@ -1288,6 +1318,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_k5 = _mm_loadu_ps(kptr+20);
_k6 = _mm_loadu_ps(kptr+24);
_k7 = _mm_loadu_ps(kptr+28);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r2, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r2, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r2, _k2, _sum2);
@@ -1296,7 +1327,16 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_sum5 = _mm_fmadd_ps(_r2, _k5, _sum5);
_sum6 = _mm_fmadd_ps(_r2, _k6, _sum6);
_sum7 = _mm_fmadd_ps(_r2, _k7, _sum7);

#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r2, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r2, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r2, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r2, _k3));
_sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r2, _k4));
_sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r2, _k5));
_sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r2, _k6));
_sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r2, _k7));
#endif
kptr += 32;
_k0 = _mm_loadu_ps(kptr);
_k1 = _mm_loadu_ps(kptr+4);
@@ -1306,6 +1346,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_k5 = _mm_loadu_ps(kptr+20);
_k6 = _mm_loadu_ps(kptr+24);
_k7 = _mm_loadu_ps(kptr+28);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r3, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r3, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r3, _k2, _sum2);
@@ -1314,7 +1355,16 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_sum5 = _mm_fmadd_ps(_r3, _k5, _sum5);
_sum6 = _mm_fmadd_ps(_r3, _k6, _sum6);
_sum7 = _mm_fmadd_ps(_r3, _k7, _sum7);

#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r3, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r3, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r3, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r3, _k3));
_sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r3, _k4));
_sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r3, _k5));
_sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r3, _k6));
_sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r3, _k7));
#endif
kptr += 32;
r0 += 16;
}
@@ -1331,6 +1381,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
__m128 _k6 = _mm_loadu_ps(kptr+24);
__m128 _k7 = _mm_loadu_ps(kptr+28);

#if __AVX__
_sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
@@ -1339,6 +1390,16 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
_sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
_sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
_sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);
#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
_sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
_sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
_sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
_sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
#endif

kptr += 32;
r0 += 4;
@@ -1390,7 +1451,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
output6_tm[n] = sum6[n];
output7_tm[n] = sum7[n];
}
#endif // __ARM_NEON
#endif // __AVX__
output0_tm += 36;
output1_tm += 36;
output2_tm += 36;
@@ -1422,13 +1483,19 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
{
const float* kptr = kernel_tm_test[r].channel(p/8 + (p%8)/4);
const float* r0 = bottom_blob_tm.channel(tiles*r+i);
#if __AVX__ || __SSE__
#if __AVX__
float zero_val = 0.f;
__m128 _sum0 = _mm_broadcast_ss(&zero_val);
__m128 _sum1 = _mm_broadcast_ss(&zero_val);
__m128 _sum2 = _mm_broadcast_ss(&zero_val);
__m128 _sum3 = _mm_broadcast_ss(&zero_val);

#else
__m128 _sum0 = _mm_set1_ps(0.f);
__m128 _sum1 = _mm_set1_ps(0.f);
__m128 _sum2 = _mm_set1_ps(0.f);
__m128 _sum3 = _mm_set1_ps(0.f);
#endif
for (int q=0; q<inch; q++)
{
__m128 _r0 = _mm_loadu_ps(r0);
@@ -1436,12 +1503,17 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
__m128 _k1 = _mm_loadu_ps(kptr+4);
__m128 _k2 = _mm_loadu_ps(kptr+8);
__m128 _k3 = _mm_loadu_ps(kptr+12);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
_sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
_sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
_sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);

#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
#endif
kptr += 16;
r0 += 4;
}
@@ -1496,21 +1568,26 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
{
const float* kptr = kernel_tm_test[r].channel(p/8 + (p%8)/4 + p%4);
const float* r0 = bottom_blob_tm.channel(tiles*r+i);
#if __AVX__ || __SSE__
#if __AVX__
float zero_val = 0.f;
__m128 _sum0 = _mm_broadcast_ss(&zero_val);
#else
__m128 _sum0 = _mm_set1_ps(0.f);
#endif

for (int q=0; q<inch; q++)
{
__m128 _r0 = _mm_loadu_ps(r0);
__m128 _k0 = _mm_loadu_ps(kptr);
#if __AVX__
_sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);

#else
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
#endif
kptr += 16;
r0 += 4;
}

_mm_storeu_ps(output0_tm, _sum0);
#else
float sum0[4] = {0};
@@ -1529,7 +1606,7 @@ static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, cons
{
output0_tm[n] = sum0[n];
}
#endif // __AVX__
#endif // __AVX__ || __SSE__
output0_tm += 36;
}
}


+ 558
- 0
src/layer/x86/convolution_sgemm.h View File

@@ -12,6 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#if __AVX__
static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{

@@ -1024,3 +1025,560 @@ static void conv_im2col_sgemm_sse(const Mat &bottom_blob, Mat &top_blob, const M
}
}
}
#else
static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
const float* kernel = _kernel;

// kernel memory packed 4 x 4
kernel_tm.create(4*kernel_size, inch, outch/4 + outch%4);
int nn_outch = 0;
int remain_outch_start = 0;

nn_outch = outch >> 2;
remain_outch_start = nn_outch << 2;

for (int pp=0; pp<nn_outch; pp++)
{
int p = pp * 4;

const float* k0 = kernel + (p+0)*inch*kernel_size;
const float* k1 = kernel + (p+1)*inch*kernel_size;
const float* k2 = kernel + (p+2)*inch*kernel_size;
const float* k3 = kernel + (p+3)*inch*kernel_size;

float* ktmp = kernel_tm.channel(p/4);

for (int q=0; q<inch*kernel_size; q++)
{
ktmp[0] = k0[0];
ktmp[1] = k1[0];
ktmp[2] = k2[0];
ktmp[3] = k3[0];
ktmp += 4;

k0 += 1;
k1 += 1;
k2 += 1;
k3 += 1;
}
}
for (int p=remain_outch_start; p<outch; p++)
{
const float* k0 = kernel + (p+0)*inch*kernel_size;

float* ktmp = kernel_tm.channel(p/4 + p%4);

for (int q=0; q<inch*kernel_size; q++)
{
ktmp[0] = k0[0];
ktmp++;
k0++;
}
}
}

static void conv_im2col_sgemm_sse(const Mat &bottom_blob, Mat &top_blob, const Mat & kernel_tm, const Mat& _bias, \
const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
int w = bottom_blob.w;
int inch = bottom_blob.c;
size_t elemsize = bottom_blob.elemsize;

int outw = top_blob.w;
int outh = top_blob.h;
int outch = top_blob.c;

const float* bias = _bias;

// im2col
Mat bottom_im2col(outw*outh, kernel_h*kernel_w*inch, elemsize, opt.workspace_allocator);
{
const int stride = kernel_h*kernel_w*outw*outh;
float* ret = (float*)bottom_im2col;
#pragma omp parallel for num_threads(opt.num_threads)
for (int p=0; p<inch; p++)
{
const float* input = bottom_blob.channel(p);
int retID = stride * p;
for (int u=0; u<kernel_h; u++)
{
for (int v=0; v<kernel_w; v++)
{
for (int i=0; i<outh; i++)
{
for (int j=0; j<outw; j++)
{
int row = u + i * stride_h;
int col = v + j * stride_w;
int index = row * w + col;
ret[retID] = input[index];
retID++;
}
}
}
}
}
}

int kernel_size = kernel_w * kernel_h;
int out_size = outw * outh;

// bottom_im2col memory packed 4 x 4
Mat bottom_tm(4*kernel_size, inch, out_size/4 + out_size%4, elemsize, opt.workspace_allocator);
{
int nn_size = out_size >> 2;
int remain_size_start = nn_size << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int ii=0; ii<nn_size; ii++)
{
int i = ii * 4;

const float* img0 = bottom_im2col.channel(0);
img0 += i;

float* tmpptr = bottom_tm.channel(i/4);

for (int q=0; q<inch*kernel_size; q++)
{
#if __SSE__
_mm_storeu_ps(tmpptr, _mm_loadu_ps(img0));
#else
tmpptr[0] = img0[0];
tmpptr[1] = img0[1];
tmpptr[2] = img0[2];
tmpptr[3] = img0[3];
#endif // __SSE__
tmpptr += 4;
img0 += out_size;
}
}

#pragma omp parallel for num_threads(opt.num_threads)
for (int i=remain_size_start; i<out_size; i++)
{
const float* img0 = bottom_im2col.channel(0);
img0 += i;

float* tmpptr = bottom_tm.channel(i/4 + i%4);

for (int q=0; q<inch*kernel_size; q++)
{
tmpptr[0] = img0[0];

tmpptr += 1;
img0 += out_size;
}
}
}
// sgemm(int M, int N, int L, float* A, float* B, float* C)
{
//int M = outch; // outch
int N = outw * outh; // outsize or out stride
int L = kernel_w * kernel_h * inch; // ksize * inch

int nn_outch = 0;
int remain_outch_start = 0;

nn_outch = outch >> 2;
remain_outch_start = nn_outch << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int pp=0; pp<nn_outch; pp++)
{
int i = pp * 4;

float* output0 = top_blob.channel(i);
float* output1 = top_blob.channel(i+1);
float* output2 = top_blob.channel(i+2);
float* output3 = top_blob.channel(i+3);

const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
const float* biasptr = bias ? bias + i : zeros;

int j=0;
for (; j+3<N; j=j+4)
{
const float* vb = bottom_tm.channel(j/4);
const float* va = kernel_tm.channel(i/4);
#if __SSE__
__m128 _sum0 = _mm_set1_ps(biasptr[0]);
__m128 _sum1 = _mm_set1_ps(biasptr[1]);
__m128 _sum2 = _mm_set1_ps(biasptr[2]);
__m128 _sum3 = _mm_set1_ps(biasptr[3]);

int k=0;
for (; k+3<L; k=k+4)
{
// k0
__m128 _vb = _mm_loadu_ps(vb);
__m128 _va0 = _mm_set1_ps(va[0]);
__m128 _va1 = _mm_set1_ps(va[1]);
__m128 _va2 = _mm_set1_ps(va[2]);
__m128 _va3 = _mm_set1_ps(va[3]);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a00-a03) * k00
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a00-a03) * k10
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a00-a03) * k20
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a00-a03) * k30

// k1
_vb = _mm_loadu_ps(vb+4);
_va0 = _mm_set1_ps(va[4]);
_va1 = _mm_set1_ps(va[5]);
_va2 = _mm_set1_ps(va[6]);
_va3 = _mm_set1_ps(va[7]);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a10-a13) * k01
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a10-a13) * k11
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a10-a13) * k21
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a10-a13) * k31

// k2
_vb = _mm_loadu_ps(vb+8);
_va0 = _mm_set1_ps(va[8]);
_va1 = _mm_set1_ps(va[9]);
_va2 = _mm_set1_ps(va[10]);
_va3 = _mm_set1_ps(va[11]);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a20-a23) * k02
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a20-a23) * k12
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a20-a23) * k22
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a20-a23) * k32

// k3
_vb = _mm_loadu_ps(vb+12);
_va0 = _mm_set1_ps(va[12]);
_va1 = _mm_set1_ps(va[13]);
_va2 = _mm_set1_ps(va[14]);
_va3 = _mm_set1_ps(va[15]);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a30-a33) * k03
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a30-a33) * k13
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a30-a33) * k23
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a30-a33) * k33

va += 16;
vb += 16;
}

for (; k<L; k++)
{
// k0
__m128 _vb = _mm_loadu_ps(vb);
__m128 _va0 = _mm_set1_ps(va[0]);
__m128 _va1 = _mm_set1_ps(va[1]);
__m128 _va2 = _mm_set1_ps(va[2]);
__m128 _va3 = _mm_set1_ps(va[3]);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0));// sum0 = (a00-a03) * k00
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1));// sum1 = (a00-a03) * k10
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2));// sum2 = (a00-a03) * k20
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3));// sum3 = (a00-a03) * k30
va += 4;
vb += 4;
}
_mm_storeu_ps(output0, _sum0);
_mm_storeu_ps(output1, _sum1);
_mm_storeu_ps(output2, _sum2);
_mm_storeu_ps(output3, _sum3);
#else
float sum0[4] = {0};
float sum1[4] = {0};
float sum2[4] = {0};
float sum3[4] = {0};
int k=0;
for (; k+7<L; k=k+8)
{
for (int n=0; n<4; n++)
{
sum0[n] += va[0] * vb[n];
sum1[n] += va[1] * vb[n];
sum2[n] += va[2] * vb[n];
sum3[n] += va[3] * vb[n];
va += 4;

sum0[n] += va[0] * vb[n+4];
sum1[n] += va[1] * vb[n+4];
sum2[n] += va[2] * vb[n+4];
sum3[n] += va[3] * vb[n+4];
va += 4;

sum0[n] += va[0] * vb[n+8];
sum1[n] += va[1] * vb[n+8];
sum2[n] += va[2] * vb[n+8];
sum3[n] += va[3] * vb[n+8];
va += 4;

sum0[n] += va[0] * vb[n+12];
sum1[n] += va[1] * vb[n+12];
sum2[n] += va[2] * vb[n+12];
sum3[n] += va[3] * vb[n+12];
va += 4;

sum0[n] += va[0] * vb[n+16];
sum1[n] += va[1] * vb[n+16];
sum2[n] += va[2] * vb[n+16];
sum3[n] += va[3] * vb[n+16];
va += 4;

sum0[n] += va[0] * vb[n+20];
sum1[n] += va[1] * vb[n+20];
sum2[n] += va[2] * vb[n+20];
sum3[n] += va[3] * vb[n+20];
va += 4;

sum0[n] += va[0] * vb[n+24];
sum1[n] += va[1] * vb[n+24];
sum2[n] += va[2] * vb[n+24];
sum3[n] += va[3] * vb[n+24];
va += 4;

sum0[n] += va[0] * vb[n+28];
sum1[n] += va[1] * vb[n+28];
sum2[n] += va[2] * vb[n+28];
sum3[n] += va[3] * vb[n+28];
va -= 28;
}

va += 32;
vb += 32;
}

for (; k<L; k++)
{
for (int n=0; n<4; n++)
{
sum0[n] += va[0] * vb[n];
sum1[n] += va[1] * vb[n];
sum2[n] += va[2] * vb[n];
sum3[n] += va[3] * vb[n];
}
va += 4;
vb += 4;
}

for (int n=0; n<4; n++)
{
output0[n] = sum0[n] + biasptr[0];
output1[n] = sum1[n] + biasptr[1];
output2[n] = sum2[n] + biasptr[2];
output3[n] = sum3[n] + biasptr[3];
}
#endif // __SSE__
output0 += 4;
output1 += 4;
output2 += 4;
output3 += 4;
}

for (; j<N; j++)
{
const float* vb = bottom_tm.channel(j/4 + j%4);
const float* va = kernel_tm.channel(i/4);
#if __SSE__
__m128 _sum0_3 = _mm_loadu_ps(biasptr);
__m128 _sum0 = _mm_set1_ps(0.0);
__m128 _sum1 = _mm_set1_ps(0.0);
__m128 _sum2 = _mm_set1_ps(0.0);
__m128 _sum3 = _mm_set1_ps(0.0);

int k=0;
for (; k+3<L; k=k+4)
{
__m128 _vb0 = _mm_set1_ps(vb[0]);
__m128 _vb1 = _mm_set1_ps(vb[1]);
__m128 _vb2 = _mm_set1_ps(vb[2]);
__m128 _vb3 = _mm_set1_ps(vb[3]);
__m128 _va0 = _mm_loadu_ps(va);
__m128 _va1 = _mm_loadu_ps(va+4);
__m128 _va2 = _mm_loadu_ps(va+8);
__m128 _va3 = _mm_loadu_ps(va+12);

_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_va0, _vb0));// sum0 += (k00-k30) * a00
_sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_va1, _vb1));// sum1 += (k01-k31) * a10
_sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_va2, _vb2));// sum2 += (k02-k32) * a20
_sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_va3, _vb3));// sum3 += (k03-k33) * a30

va += 16;
vb += 4;
}

_sum0 = _mm_add_ps(_sum0, _sum1);
_sum2 = _mm_add_ps(_sum2, _sum3);
_sum0_3 = _mm_add_ps(_sum0_3, _sum0);
_sum0_3 = _mm_add_ps(_sum0_3, _sum2);

for (; k<L; k++)
{
__m128 _vb0 = _mm_set1_ps(vb[0]);
__m128 _va = _mm_loadu_ps(va);

_sum0_3 = _mm_add_ps(_sum0_3, _mm_mul_ps(_va, _vb0));// sum0 += (k00-k30) * a00

va += 4;
vb += 1;
}
output0[0] = _sum0_3[0];
output1[0] = _sum0_3[1];
output2[0] = _sum0_3[2];
output3[0] = _sum0_3[3];
#else
float sum0 = biasptr[0];
float sum1 = biasptr[1];
float sum2 = biasptr[2];
float sum3 = biasptr[3];

for (int k=0; k<L; k++)
{
sum0 += va[0] * vb[0];
sum1 += va[1] * vb[0];
sum2 += va[2] * vb[0];
sum3 += va[3] * vb[0];

va += 4;
vb += 1;
}
output0[0] = sum0;
output1[0] = sum1;
output2[0] = sum2;
output3[0] = sum3;
#endif // __SSE__
output0++;
output1++;
output2++;
output3++;
}
}

#pragma omp parallel for num_threads(opt.num_threads)
for (int i=remain_outch_start; i<outch; i++)
{
float* output = top_blob.channel(i);

const float bias0 = bias ? bias[i] : 0.f;

int j=0;
for (; j+3<N; j=j+4)
{
const float* vb = bottom_tm.channel(j/4);
const float* va = kernel_tm.channel(i/4 + i%4);
#if __SSE__
__m128 _sum0 = _mm_set1_ps(bias0);

int k=0;
for (; k+3<L; k=k+4)
{
// k0
__m128 _va0 = _mm_set1_ps(va[0]);
__m128 _va1 = _mm_set1_ps(va[1]);
__m128 _va2 = _mm_set1_ps(va[2]);
__m128 _va3 = _mm_set1_ps(va[3]);
__m128 _vb0 = _mm_loadu_ps(vb);
__m128 _vb1 = _mm_loadu_ps(vb+4);
__m128 _vb2 = _mm_loadu_ps(vb+8);
__m128 _vb3 = _mm_loadu_ps(vb+12);

_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0));// sum0 = (a00-a03) * k00
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb1, _va1));// sum0 += (a10-a13) * k01
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb2, _va2));// sum0 += (a20-a23) * k02
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb3, _va3));// sum0 += (a30-a33) * k03
va += 4;
vb += 16;
}

for (; k<L; k++)
{
// k0
__m128 _va0 = _mm_set1_ps(va[0]);
__m128 _vb0 = _mm_loadu_ps(vb);

_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00

va += 1;
vb += 4;
}
_mm_storeu_ps(output, _sum0);
#else
float sum[4] = {0};

int k=0;
for (; k+3<L; k=k+4)
{
for (int n=0; n<4; n++)
{
sum[n] += va[0] * vb[n];
sum[n] += va[1] * vb[n+4];
sum[n] += va[2] * vb[n+8];
sum[n] += va[3] * vb[n+12];
sum[n] += va[4] * vb[n+16];
sum[n] += va[5] * vb[n+20];
sum[n] += va[6] * vb[n+24];
sum[n] += va[7] * vb[n+28];
}

va += 8;
vb += 32;
}

for (; k<L; k++)
{
for (int n=0; n<4; n++)
{
sum[n] += va[0] * vb[n];
}

va += 1;
vb += 4;
}

for (int n=0; n<4; n++)
{
output[n] = sum[n] + bias0;
}
#endif // __SSE__
output += 4;
}

for (; j<N; j++)
{
const float* vb = bottom_tm.channel(j/4 + j%4);
const float* va = kernel_tm.channel(i/4 + i%4);

int k=0;
#if __SSE__
__m128 _sum0 = _mm_set1_ps(0.f);

for (; k+3<L; k+=4)
{
__m128 _p0 = _mm_loadu_ps(vb);
__m128 _k0 = _mm_loadu_ps(va);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));

va += 4;
vb += 4;
}
float sum0 = bias0 + _sum0[0] + _sum0[1] + _sum0[2] + _sum0[3];
#else
float sum0 = bias0;
#endif // __SSE__
for (; k<L; k++)
{
sum0 += va[0] * vb[0];

va += 1;
vb += 1;
}
output[0] = sum0;

output++;
}
}
}
}
#endif

+ 4
- 1
src/layer/x86/convolution_x86.cpp View File

@@ -15,7 +15,10 @@
#include "convolution_x86.h"

#include "platform.h"
#if NCNN_AVX2
#if __SSE2__
#include <emmintrin.h>
#endif
#if __AVX__
#include <immintrin.h>
#endif



Loading…
Cancel
Save