Browse Source

unroll outch for convolution 1x1 stride 1

tags/20171225
nihuini 8 years ago
parent
commit
55ec189998
1 changed files with 467 additions and 4 deletions
  1. +467
    -4
      src/layer/arm/convolution_1x1.h

+ 467
- 4
src/layer/arm/convolution_1x1.h View File

@@ -27,8 +27,471 @@ static void conv1x1s1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke
const float* kernel = _kernel;
const float* bias = _bias;

int nn_outch = outch >> 2;
int remain_outch_start = nn_outch << 2;

#pragma omp parallel for
for (int p=0; p<outch; p++)
for (int pp=0; pp<nn_outch; pp++)
{
int p = pp * 4;

Mat out0 = top_blob.channel(p);
Mat out1 = top_blob.channel(p+1);
Mat out2 = top_blob.channel(p+2);
Mat out3 = top_blob.channel(p+3);

const float bias0 = bias ? bias[p] : 0.f;
const float bias1 = bias ? bias[p+1] : 0.f;
const float bias2 = bias ? bias[p+2] : 0.f;
const float bias3 = bias ? bias[p+3] : 0.f;

out0.fill(bias0);
out1.fill(bias1);
out2.fill(bias2);
out3.fill(bias3);

int q = 0;

for (; q+3<inch; q+=4)
{
float* outptr0 = out0;
float* outptr1 = out1;
float* outptr2 = out2;
float* outptr3 = out3;

const float* img0 = bottom_blob.channel(q);
const float* img1 = bottom_blob.channel(q+1);
const float* img2 = bottom_blob.channel(q+2);
const float* img3 = bottom_blob.channel(q+3);

const float* kernel0 = kernel + p*inch + q;
const float* kernel1 = kernel + (p+1)*inch + q;
const float* kernel2 = kernel + (p+2)*inch + q;
const float* kernel3 = kernel + (p+3)*inch + q;

const float* r0 = img0;
const float* r1 = img1;
const float* r2 = img2;
const float* r3 = img3;

int size = outw * outh;

#if __ARM_NEON
int nn = size >> 3;
int remain = size & 7;
#else
int remain = size;
#endif // __ARM_NEON

#if __ARM_NEON
float32x4_t _k0 = vld1q_f32(kernel0);
float32x4_t _k1 = vld1q_f32(kernel1);
float32x4_t _k2 = vld1q_f32(kernel2);
float32x4_t _k3 = vld1q_f32(kernel3);
#if __aarch64__
for (; nn>0; nn--)
{
float32x4_t _p = vld1q_f32(r0);
float32x4_t _pn = vld1q_f32(r0+4);

float32x4_t _out0p = vld1q_f32(outptr0);
float32x4_t _out0pn = vld1q_f32(outptr0+4);

float32x4_t _out1p = vld1q_f32(outptr1);
float32x4_t _out1pn = vld1q_f32(outptr1+4);

float32x4_t _out2p = vld1q_f32(outptr2);
float32x4_t _out2pn = vld1q_f32(outptr2+4);

float32x4_t _out3p = vld1q_f32(outptr3);
float32x4_t _out3pn = vld1q_f32(outptr3+4);

_out0p = vfmaq_laneq_f32(_out0p, _p, _k0, 0);
_out0pn = vfmaq_laneq_f32(_out0pn, _pn, _k0, 0);

_out1p = vfmaq_laneq_f32(_out1p, _p, _k1, 0);
_out1pn = vfmaq_laneq_f32(_out1pn, _pn, _k1, 0);

_out2p = vfmaq_laneq_f32(_out2p, _p, _k2, 0);
_out2pn = vfmaq_laneq_f32(_out2pn, _pn, _k2, 0);

_out3p = vfmaq_laneq_f32(_out3p, _p, _k3, 0);
_out3pn = vfmaq_laneq_f32(_out3pn, _pn, _k3, 0);

float32x4_t _p1 = vld1q_f32(r1);
float32x4_t _p1n = vld1q_f32(r1+4);

_out0p = vfmaq_laneq_f32(_out0p, _p1, _k0, 1);
_out0pn = vfmaq_laneq_f32(_out0pn, _p1n, _k0, 1);

_out1p = vfmaq_laneq_f32(_out1p, _p1, _k1, 1);
_out1pn = vfmaq_laneq_f32(_out1pn, _p1n, _k1, 1);

_out2p = vfmaq_laneq_f32(_out2p, _p1, _k2, 1);
_out2pn = vfmaq_laneq_f32(_out2pn, _p1n, _k2, 1);

_out3p = vfmaq_laneq_f32(_out3p, _p1, _k3, 1);
_out3pn = vfmaq_laneq_f32(_out3pn, _p1n, _k3, 1);

float32x4_t _p2 = vld1q_f32(r2);
float32x4_t _p2n = vld1q_f32(r2+4);

_out0p = vfmaq_laneq_f32(_out0p, _p2, _k0, 2);
_out0pn = vfmaq_laneq_f32(_out0pn, _p2n, _k0, 2);

_out1p = vfmaq_laneq_f32(_out1p, _p2, _k1, 2);
_out1pn = vfmaq_laneq_f32(_out1pn, _p2n, _k1, 2);

_out2p = vfmaq_laneq_f32(_out2p, _p2, _k2, 2);
_out2pn = vfmaq_laneq_f32(_out2pn, _p2n, _k2, 2);

_out3p = vfmaq_laneq_f32(_out3p, _p2, _k3, 2);
_out3pn = vfmaq_laneq_f32(_out3pn, _p2n, _k3, 2);

float32x4_t _p3 = vld1q_f32(r3);
float32x4_t _p3n = vld1q_f32(r3+4);

_out0p = vfmaq_laneq_f32(_out0p, _p3, _k0, 3);
_out0pn = vfmaq_laneq_f32(_out0pn, _p3n, _k0, 3);

_out1p = vfmaq_laneq_f32(_out1p, _p3, _k1, 3);
_out1pn = vfmaq_laneq_f32(_out1pn, _p3n, _k1, 3);

_out2p = vfmaq_laneq_f32(_out2p, _p3, _k2, 3);
_out2pn = vfmaq_laneq_f32(_out2pn, _p3n, _k2, 3);

_out3p = vfmaq_laneq_f32(_out3p, _p3, _k3, 3);
_out3pn = vfmaq_laneq_f32(_out3pn, _p3n, _k3, 3);

vst1q_f32(outptr0, _out0p);
vst1q_f32(outptr0+4, _out0pn);

vst1q_f32(outptr1, _out1p);
vst1q_f32(outptr1+4, _out1pn);

vst1q_f32(outptr2, _out2p);
vst1q_f32(outptr2+4, _out2pn);

vst1q_f32(outptr3, _out3p);
vst1q_f32(outptr3+4, _out3pn);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
outptr0 += 8;
outptr1 += 8;
outptr2 += 8;
outptr3 += 8;
}
#else
if (nn > 0)
{
asm volatile(
"pld [%5, #256] \n"
"vld1.f32 {d12-d15}, [%5 :128]! \n"
"0: \n"
"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"

"vmla.f32 q8, q6, %e18[0] \n"
"vmla.f32 q9, q7, %e18[0] \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"

"vmla.f32 q10, q6, %e19[0] \n"
"vmla.f32 q11, q7, %e19[0] \n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128] \n"

"vmla.f32 q12, q6, %e20[0] \n"
"vmla.f32 q13, q7, %e20[0] \n"

"pld [%4, #256] \n"
"vld1.f32 {d28-d31}, [%4 :128] \n"

"vmla.f32 q14, q6, %e21[0] \n"
"vmla.f32 q15, q7, %e21[0] \n"

"pld [%6, #256] \n"
"vld1.f32 {d12-d15}, [%6 :128]! \n"

"vmla.f32 q8, q6, %e18[1] \n"
"vmla.f32 q9, q7, %e18[1] \n"

"vmla.f32 q10, q6, %e19[1] \n"
"vmla.f32 q11, q7, %e19[1] \n"

"vmla.f32 q12, q6, %e20[1] \n"
"vmla.f32 q13, q7, %e20[1] \n"

"vmla.f32 q14, q6, %e21[1] \n"
"vmla.f32 q15, q7, %e21[1] \n"

"pld [%7, #256] \n"
"vld1.f32 {d12-d15}, [%7 :128]! \n"

"vmla.f32 q8, q6, %f18[0] \n"
"vmla.f32 q9, q7, %f18[0] \n"

"vmla.f32 q10, q6, %f19[0] \n"
"vmla.f32 q11, q7, %f19[0] \n"

"vmla.f32 q12, q6, %f20[0] \n"
"vmla.f32 q13, q7, %f20[0] \n"

"vmla.f32 q14, q6, %f21[0] \n"
"vmla.f32 q15, q7, %f21[0] \n"

"pld [%8, #256] \n"
"vld1.f32 {d12-d15}, [%8 :128]! \n"

"vmla.f32 q8, q6, %f18[1] \n"
"vmla.f32 q9, q7, %f18[1] \n"

"vmla.f32 q10, q6, %f19[1] \n"
"vmla.f32 q11, q7, %f19[1] \n"

"vst1.f32 {d16-d19}, [%1 :128]! \n"

"vmla.f32 q12, q6, %f20[1] \n"
"vmla.f32 q13, q7, %f20[1] \n"

"vst1.f32 {d20-d23}, [%2 :128]! \n"

"vmla.f32 q14, q6, %f21[1] \n"
"vmla.f32 q15, q7, %f21[1] \n"

"vst1.f32 {d24-d27}, [%3 :128]! \n"

"pld [%5, #256] \n"
"vld1.f32 {d12-d15}, [%5 :128]! \n"

"subs %0, #1 \n"
"vst1.f32 {d28-d31}, [%4 :128]! \n"

"bne 0b \n"
"sub %5, #32 \n"
: "=r"(nn), // %0
"=r"(outptr0),// %1
"=r"(outptr1),// %2
"=r"(outptr2),// %3
"=r"(outptr3),// %4
"=r"(r0), // %5
"=r"(r1), // %6
"=r"(r2), // %7
"=r"(r3) // %8
: "0"(nn),
"1"(outptr0),
"2"(outptr1),
"3"(outptr2),
"4"(outptr3),
"5"(r0),
"6"(r1),
"7"(r2),
"8"(r3),
"w"(_k0), // %18
"w"(_k1), // %19
"w"(_k2), // %20
"w"(_k3) // %21
: "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"
);
}
#endif // __aarch64__
#endif // __ARM_NEON
for (; remain>0; remain--)
{
// TODO neon optimize
float sum0 = *r0 * kernel0[0] + *r1 * kernel0[1] + *r2 * kernel0[2] + *r3 * kernel0[3];
float sum1 = *r0 * kernel1[0] + *r1 * kernel1[1] + *r2 * kernel1[2] + *r3 * kernel1[3];
float sum2 = *r0 * kernel2[0] + *r1 * kernel2[1] + *r2 * kernel2[2] + *r3 * kernel2[3];
float sum3 = *r0 * kernel3[0] + *r1 * kernel3[1] + *r2 * kernel3[2] + *r3 * kernel3[3];

*outptr0 += sum0;
*outptr1 += sum1;
*outptr2 += sum2;
*outptr3 += sum3;

r0++;
r1++;
r2++;
r3++;
outptr0++;
outptr1++;
outptr2++;
outptr3++;
}
}

for (; q<inch; q++)
{
float* outptr0 = out0;
float* outptr1 = out1;
float* outptr2 = out2;
float* outptr3 = out3;

const float* img0 = bottom_blob.channel(q);

const float* kernel0 = kernel + p*inch + q;
const float* kernel1 = kernel + (p+1)*inch + q;
const float* kernel2 = kernel + (p+2)*inch + q;
const float* kernel3 = kernel + (p+3)*inch + q;

const float k0 = kernel0[0];
const float k1 = kernel1[0];
const float k2 = kernel2[0];
const float k3 = kernel3[0];

const float* r0 = img0;

int size = outw * outh;

#if __ARM_NEON
int nn = size >> 3;
int remain = size & 7;
#else
int remain = size;
#endif // __ARM_NEON

#if __ARM_NEON
float32x4_t _k0 = vdupq_n_f32(k0);
float32x4_t _k1 = vdupq_n_f32(k1);
float32x4_t _k2 = vdupq_n_f32(k2);
float32x4_t _k3 = vdupq_n_f32(k3);
#if __aarch64__
for (; nn>0; nn--)
{
float32x4_t _p = vld1q_f32(r0);
float32x4_t _pn = vld1q_f32(r0+4);

float32x4_t _out0p = vld1q_f32(outptr0);
float32x4_t _out0pn = vld1q_f32(outptr0+4);

float32x4_t _out1p = vld1q_f32(outptr1);
float32x4_t _out1pn = vld1q_f32(outptr1+4);

float32x4_t _out2p = vld1q_f32(outptr2);
float32x4_t _out2pn = vld1q_f32(outptr2+4);

float32x4_t _out3p = vld1q_f32(outptr3);
float32x4_t _out3pn = vld1q_f32(outptr3+4);

_out0p = vfmaq_f32(_out0p, _p, _k0);
_out0pn = vfmaq_f32(_out0pn, _pn, _k0);

_out1p = vfmaq_f32(_out1p, _p, _k1);
_out1pn = vfmaq_f32(_out1pn, _pn, _k1);

_out2p = vfmaq_f32(_out2p, _p, _k2);
_out2pn = vfmaq_f32(_out2pn, _pn, _k2);

_out3p = vfmaq_f32(_out3p, _p, _k3);
_out3pn = vfmaq_f32(_out3pn, _pn, _k3);

vst1q_f32(outptr0, _out0p);
vst1q_f32(outptr0+4, _out0pn);

vst1q_f32(outptr1, _out1p);
vst1q_f32(outptr1+4, _out1pn);

vst1q_f32(outptr2, _out2p);
vst1q_f32(outptr2+4, _out2pn);

vst1q_f32(outptr3, _out3p);
vst1q_f32(outptr3+4, _out3pn);

r0 += 8;
outptr0 += 8;
outptr1 += 8;
outptr2 += 8;
outptr3 += 8;
}
#else
if (nn > 0)
{
asm volatile(
"pld [%5, #256] \n"
"vld1.f32 {d12-d15}, [%5 :128]! \n"
"0: \n"
"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"
"vmla.f32 q8, q6, %q12 \n"
"vmla.f32 q9, q7, %q12 \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"
"vmla.f32 q10, q2, %q13 \n"
"vmla.f32 q11, q3, %q13 \n"

"vst1.f32 {d16-d19}, [%1 :128]! \n"

"pld [%2, #256] \n"
"vld1.f32 {d24-d27}, [%2 :128] \n"
"vmla.f32 q12, q2, %q14 \n"
"vmla.f32 q13, q3, %q14 \n"

"vst1.f32 {d20-d23}, [%2 :128]! \n"

"pld [%2, #256] \n"
"vld1.f32 {d28-d31}, [%2 :128] \n"
"vmla.f32 q14, q2, %q15 \n"
"vmla.f32 q15, q3, %q15 \n"

"vst1.f32 {d24-d27}, [%3 :128]! \n"

"pld [%5, #256] \n"
"vld1.f32 {d12-d15}, [%5 :128]! \n"
"subs %0, #1 \n"
"vst1.f32 {d28-d31}, [%4 :128]! \n"
"bne 0b \n"
"sub %5, #32 \n"
: "=r"(nn), // %0
"=r"(outptr0),// %1
"=r"(outptr1),// %2
"=r"(outptr2),// %3
"=r"(outptr3),// %4
"=r"(r0) // %5
: "0"(nn),
"1"(outptr0),
"2"(outptr1),
"3"(outptr2),
"4"(outptr3),
"5"(r0),
"w"(_k0), // %12
"w"(_k1), // %13
"w"(_k2), // %14
"w"(_k3) // %15
: "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"
);
}
#endif // __aarch64__
#endif // __ARM_NEON
for (; remain>0; remain--)
{
// TODO neon optimize
float sum0 = *r0 * k0;
float sum1 = *r0 * k1;
float sum2 = *r0 * k2;
float sum3 = *r0 * k3;

*outptr0 += sum0;
*outptr1 += sum1;
*outptr2 += sum2;
*outptr3 += sum3;

r0++;
outptr0++;
outptr1++;
outptr2++;
outptr3++;
}
}
}

#pragma omp parallel for
for (int p=remain_outch_start; p<outch; p++)
{
Mat out = top_blob.channel(p);

@@ -47,7 +510,7 @@ static void conv1x1s1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke
const float* img2 = bottom_blob.channel(q+2);
const float* img3 = bottom_blob.channel(q+3);

const float* kernel0 = kernel + p*inch + q;
const float* kernel0 = kernel + p*inch + q;
const float k0 = kernel0[0];
const float k1 = kernel0[1];
const float k2 = kernel0[2];
@@ -297,7 +760,7 @@ static void conv1x1s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke
const float* img2 = bottom_blob.channel(q+2);
const float* img3 = bottom_blob.channel(q+3);

const float* kernel0 = kernel + p*inch + q;
const float* kernel0 = kernel + p*inch + q;
const float k0 = kernel0[0];
const float k1 = kernel0[1];
const float k2 = kernel0[2];
@@ -454,7 +917,7 @@ static void conv1x1s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke

const float* img0 = bottom_blob.channel(q);

const float* kernel0 = kernel + p*inch + q;
const float* kernel0 = kernel + p*inch + q;
const float k0 = kernel0[0];

const float* r0 = img0;


Loading…
Cancel
Save