Browse Source

unroll outch for convolution 3x3 winograd64, reduce memory usage

tags/20180129
nihuini 8 years ago
parent
commit
9280a068fe
1 changed files with 523 additions and 1 deletions
  1. +523
    -1
      src/layer/arm/convolution_3x3.h

+ 523
- 1
src/layer/arm/convolution_3x3.h View File

@@ -683,6 +683,7 @@ static void conv3x3s1_winograd64_neon(const Mat& bottom_blob, Mat& top_blob, con
}

}
bottom_blob_bordered = Mat();
// END transform input

// BEGIN dot
@@ -692,8 +693,528 @@ static void conv3x3s1_winograd64_neon(const Mat& bottom_blob, Mat& top_blob, con
int h_tm = outh / 6 * 8;
top_blob_tm.create(8*8, w_tm/8 * h_tm/8, outch);

int nn_outch = outch >> 2;
int remain_outch_start = nn_outch << 2;

#pragma omp parallel for
for (int p = 0; p<outch; p++)
for (int pp=0; pp<nn_outch; pp++)
{
int p = pp * 4;

Mat out0_tm = top_blob_tm.channel(p);
Mat out1_tm = top_blob_tm.channel(p+1);
Mat out2_tm = top_blob_tm.channel(p+2);
Mat out3_tm = top_blob_tm.channel(p+3);
const Mat kernel0_tm = kernel_tm.channel(p);
const Mat kernel1_tm = kernel_tm.channel(p+1);
const Mat kernel2_tm = kernel_tm.channel(p+2);
const Mat kernel3_tm = kernel_tm.channel(p+3);

out0_tm.fill(0.f);
out1_tm.fill(0.f);
out2_tm.fill(0.f);
out3_tm.fill(0.f);

int q = 0;
for (; q+3<inch; q+=4)
{
const float* r0 = bottom_blob_tm.channel(q);
const float* r1 = bottom_blob_tm.channel(q+1);
const float* r2 = bottom_blob_tm.channel(q+2);
const float* r3 = bottom_blob_tm.channel(q+3);

const float* k00 = kernel0_tm.row(q);
const float* k10 = kernel1_tm.row(q);
const float* k20 = kernel2_tm.row(q);
const float* k30 = kernel3_tm.row(q);

float* output0_tm = out0_tm;
float* output1_tm = out1_tm;
float* output2_tm = out2_tm;
float* output3_tm = out3_tm;

// tile
for (int i=0; i<h_tm/8 * w_tm/8; i++)
{
#if __ARM_NEON
#if __aarch64__
for (int m=0; m+7<64; m+=8)
{
float32x4_t _output0_tm = vld1q_f32(output0_tm);
float32x4_t _output1_tm = vld1q_f32(output1_tm);
float32x4_t _output2_tm = vld1q_f32(output2_tm);
float32x4_t _output3_tm = vld1q_f32(output3_tm);

float32x4_t _r0 = vld1q_f32(r0);
float32x4_t _r1 = vld1q_f32(r1);
float32x4_t _r2 = vld1q_f32(r2);
float32x4_t _r3 = vld1q_f32(r3);

float32x4_t _k00 = vld1q_f32(k00);
k00 += 64;
float32x4_t _k01 = vld1q_f32(k00);
k00 += 64;
float32x4_t _k02 = vld1q_f32(k00);
k00 += 64;
float32x4_t _k03 = vld1q_f32(k00);
k00 += 64;

k00 -= 64*4;

_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tm = vmlaq_f32(_output0_tm, _r1, _k01);
_output0_tm = vmlaq_f32(_output0_tm, _r2, _k02);
_output0_tm = vmlaq_f32(_output0_tm, _r3, _k03);

float32x4_t _k10 = vld1q_f32(k10);
k10 += 64;
float32x4_t _k11 = vld1q_f32(k10);
k10 += 64;
float32x4_t _k12 = vld1q_f32(k10);
k10 += 64;
float32x4_t _k13 = vld1q_f32(k10);
k10 += 64;

k10 -= 64*4;

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tm = vmlaq_f32(_output1_tm, _r1, _k11);
_output1_tm = vmlaq_f32(_output1_tm, _r2, _k12);
_output1_tm = vmlaq_f32(_output1_tm, _r3, _k13);

float32x4_t _k20 = vld1q_f32(k20);
k20 += 64;
float32x4_t _k21 = vld1q_f32(k20);
k20 += 64;
float32x4_t _k22 = vld1q_f32(k20);
k20 += 64;
float32x4_t _k23 = vld1q_f32(k20);
k20 += 64;

k20 -= 64*4;

_output2_tm = vmlaq_f32(_output2_tm, _r0, _k20);
_output2_tm = vmlaq_f32(_output2_tm, _r1, _k21);
_output2_tm = vmlaq_f32(_output2_tm, _r2, _k22);
_output2_tm = vmlaq_f32(_output2_tm, _r3, _k23);

float32x4_t _k30 = vld1q_f32(k30);
k30 += 64;
float32x4_t _k31 = vld1q_f32(k30);
k30 += 64;
float32x4_t _k32 = vld1q_f32(k30);
k30 += 64;
float32x4_t _k33 = vld1q_f32(k30);
k30 += 64;

k30 -= 64*4;

_output3_tm = vmlaq_f32(_output3_tm, _r0, _k30);
_output3_tm = vmlaq_f32(_output3_tm, _r1, _k31);
_output3_tm = vmlaq_f32(_output3_tm, _r2, _k32);
_output3_tm = vmlaq_f32(_output3_tm, _r3, _k33);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output2_tm, _output2_tm);
vst1q_f32(output3_tm, _output3_tm);

output0_tm += 4;
output1_tm += 4;
output2_tm += 4;
output3_tm += 4;

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;

k00 += 4;
k10 += 4;
k20 += 4;
k30 += 4;

float32x4_t _output0_tmn = vld1q_f32(output0_tm);
float32x4_t _output1_tmn = vld1q_f32(output1_tm);
float32x4_t _output2_tmn = vld1q_f32(output2_tm);
float32x4_t _output3_tmn = vld1q_f32(output3_tm);

float32x4_t _r0n = vld1q_f32(r0);
float32x4_t _r1n = vld1q_f32(r1);
float32x4_t _r2n = vld1q_f32(r2);
float32x4_t _r3n = vld1q_f32(r3);

float32x4_t _k00n = vld1q_f32(k00);
k00 += 64;
float32x4_t _k01n = vld1q_f32(k00);
k00 += 64;
float32x4_t _k02n = vld1q_f32(k00);
k00 += 64;
float32x4_t _k03n = vld1q_f32(k00);
k00 += 64;

k00 -= 64*4;

_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);
_output0_tmn = vmlaq_f32(_output0_tmn, _r1n, _k01n);
_output0_tmn = vmlaq_f32(_output0_tmn, _r2n, _k02n);
_output0_tmn = vmlaq_f32(_output0_tmn, _r3n, _k03n);

float32x4_t _k10n = vld1q_f32(k10);
k10 += 64;
float32x4_t _k11n = vld1q_f32(k10);
k10 += 64;
float32x4_t _k12n = vld1q_f32(k10);
k10 += 64;
float32x4_t _k13n = vld1q_f32(k10);
k10 += 64;

k10 -= 64*4;

_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);
_output1_tmn = vmlaq_f32(_output1_tmn, _r1n, _k11n);
_output1_tmn = vmlaq_f32(_output1_tmn, _r2n, _k12n);
_output1_tmn = vmlaq_f32(_output1_tmn, _r3n, _k13n);

float32x4_t _k20n = vld1q_f32(k20);
k20 += 64;
float32x4_t _k21n = vld1q_f32(k20);
k20 += 64;
float32x4_t _k22n = vld1q_f32(k20);
k20 += 64;
float32x4_t _k23n = vld1q_f32(k20);
k20 += 64;

k20 -= 64*4;

_output2_tmn = vmlaq_f32(_output2_tmn, _r0n, _k20n);
_output2_tmn = vmlaq_f32(_output2_tmn, _r1n, _k21n);
_output2_tmn = vmlaq_f32(_output2_tmn, _r2n, _k22n);
_output2_tmn = vmlaq_f32(_output2_tmn, _r3n, _k23n);

float32x4_t _k30n = vld1q_f32(k30);
k30 += 64;
float32x4_t _k31n = vld1q_f32(k30);
k30 += 64;
float32x4_t _k32n = vld1q_f32(k30);
k30 += 64;
float32x4_t _k33n = vld1q_f32(k30);
k30 += 64;

k30 -= 64*4;

_output3_tmn = vmlaq_f32(_output3_tmn, _r0n, _k30n);
_output3_tmn = vmlaq_f32(_output3_tmn, _r1n, _k31n);
_output3_tmn = vmlaq_f32(_output3_tmn, _r2n, _k32n);
_output3_tmn = vmlaq_f32(_output3_tmn, _r3n, _k33n);

vst1q_f32(output0_tm, _output0_tmn);
vst1q_f32(output1_tm, _output1_tmn);
vst1q_f32(output2_tm, _output2_tmn);
vst1q_f32(output3_tm, _output3_tmn);

output0_tm += 4;
output1_tm += 4;
output2_tm += 4;
output3_tm += 4;

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;

k00 += 4;
k10 += 4;
k20 += 4;
k30 += 4;
}
#else // __aarch64__
asm volatile(
"mov r4, #8 \n"

"pld [%0, #256] \n"
"vld1.f32 {d16-d19}, [%0 :128]\n"//q8 q9 = _output0_tm

"0: \n"

"pld [%4, #256] \n"
"vld1.f32 {d0-d3}, [%4 :128]! \n"//q0 q1 = _r0

"pld [%8, #256] \n"
"vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k00
"add %8, %8, #256 \n"

"vmla.f32 q8, q0, q10 \n"
"vmla.f32 q9, q1, q11 \n"

"pld [%1, #256] \n"
"vld1.f32 {d24-d27}, [%1 :128]\n"//q12 q13 = _output1_tm

"pld [%9, #256] \n"
"vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k10
"add %9, %9, #256 \n"

"vmla.f32 q12, q0, q14 \n"
"vmla.f32 q13, q1, q15 \n"

"pld [%5, #256] \n"
"vld1.f32 {d4-d7}, [%5 :128]! \n"//q2 q3 = _r1

"pld [%8, #256] \n"
"vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k01
"add %8, %8, #256 \n"

"vmla.f32 q8, q2, q10 \n"
"vmla.f32 q9, q3, q11 \n"

"pld [%9, #256] \n"
"vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k11
"add %9, %9, #256 \n"

"vmla.f32 q12, q2, q14 \n"
"vmla.f32 q13, q3, q15 \n"

"pld [%6, #256] \n"
"vld1.f32 {d8-d11}, [%6 :128]!\n"//q4 q5 = _r2

"pld [%8, #256] \n"
"vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k02
"add %8, %8, #256 \n"

"vmla.f32 q8, q4, q10 \n"
"vmla.f32 q9, q5, q11 \n"

"pld [%9, #256] \n"
"vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k12
"add %9, %9, #256 \n"

"vmla.f32 q12, q4, q14 \n"
"vmla.f32 q13, q5, q15 \n"

"pld [%7, #256] \n"
"vld1.f32 {d12-d15}, [%7 :128]!\n"//q6 q7 = _r3

"pld [%8, #256] \n"
"vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k03
"sub %8, %8, #736 \n"

"vmla.f32 q8, q6, q10 \n"
"vmla.f32 q9, q7, q11 \n"

"pld [%9, #256] \n"
"vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k13
"sub %9, %9, #736 \n"

"vmla.f32 q12, q6, q14 \n"
"vmla.f32 q13, q7, q15 \n"

"vst1.f32 {d16-d19}, [%0 :128]!\n"

"pld [%2, #256] \n"
"vld1.f32 {d16-d19}, [%2 :128]\n"//q8 q9 = _output2_tm

"pld [%10, #256] \n"
"vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k20
"add %10, %10, #256 \n"

"vmla.f32 q8, q0, q10 \n"
"vmla.f32 q9, q1, q11 \n"

"vst1.f32 {d24-d27}, [%1 :128]!\n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128]\n"//q12 q13 = _output3_tm

"pld [%11, #256] \n"
"vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k30
"add %11, %11, #256 \n"

"vmla.f32 q12, q0, q14 \n"
"vmla.f32 q13, q1, q15 \n"

"pld [%10, #256] \n"
"vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k21
"add %10, %10, #256 \n"

"vmla.f32 q8, q2, q10 \n"
"vmla.f32 q9, q3, q11 \n"

"pld [%11, #256] \n"
"vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k31
"add %11, %11, #256 \n"

"vmla.f32 q12, q2, q14 \n"
"vmla.f32 q13, q3, q15 \n"

"pld [%10, #256] \n"
"vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k22
"add %10, %10, #256 \n"

"vmla.f32 q8, q4, q10 \n"
"vmla.f32 q9, q5, q11 \n"

"pld [%11, #256] \n"
"vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k32
"add %11, %11, #256 \n"

"vmla.f32 q12, q4, q14 \n"
"vmla.f32 q13, q5, q15 \n"

"pld [%10, #256] \n"
"vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k23
"sub %10, %10, #736 \n"

"vmla.f32 q8, q6, q10 \n"
"vmla.f32 q9, q7, q11 \n"

"pld [%11, #256] \n"
"vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k33
"sub %11, %11, #736 \n"

"vmla.f32 q12, q6, q14 \n"
"vmla.f32 q13, q7, q15 \n"

"vst1.f32 {d16-d19}, [%2 :128]!\n"

"pld [%0, #256] \n"
"vld1.f32 {d16-d19}, [%0 :128]\n"//q8 q9 = _output0_tm

"subs r4, r4, #1 \n"

"vst1.f32 {d24-d27}, [%3 :128]!\n"

"bne 0b \n"

: "=r"(output0_tm), // %0
"=r"(output1_tm), // %1
"=r"(output2_tm), // %2
"=r"(output3_tm), // %3
"=r"(r0), // %4
"=r"(r1), // %5
"=r"(r2), // %6
"=r"(r3), // %7
"=r"(k00), // %8
"=r"(k10), // %9
"=r"(k20), // %10
"=r"(k30) // %11
: "0"(output0_tm),
"1"(output1_tm),
"2"(output2_tm),
"3"(output3_tm),
"4"(r0),
"5"(r1),
"6"(r2),
"7"(r3),
"8"(k00),
"9"(k10),
"10"(k20),
"11"(k30)
: "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"
);
#endif // __aarch64__

k00 -= 64;
k10 -= 64;
k20 -= 64;
k30 -= 64;
#else
for (int m=0; m<64; m++)
{
output0_tm[m] += r0[m] * k00[m];
k00 += 64;
output0_tm[m] += r1[m] * k00[m];
k00 += 64;
output0_tm[m] += r2[m] * k00[m];
k00 += 64;
output0_tm[m] += r3[m] * k00[m];
k00 += 64;

k00 -= 64 * 4;

output1_tm[m] += r0[m] * k10[m];
k10 += 64;
output1_tm[m] += r1[m] * k10[m];
k10 += 64;
output1_tm[m] += r2[m] * k10[m];
k10 += 64;
output1_tm[m] += r3[m] * k10[m];
k10 += 64;

k10 -= 64 * 4;

output2_tm[m] += r0[m] * k20[m];
k20 += 64;
output2_tm[m] += r1[m] * k20[m];
k20 += 64;
output2_tm[m] += r2[m] * k20[m];
k20 += 64;
output2_tm[m] += r3[m] * k20[m];
k20 += 64;

k20 -= 64 * 4;

output3_tm[m] += r0[m] * k30[m];
k30 += 64;
output3_tm[m] += r1[m] * k30[m];
k30 += 64;
output3_tm[m] += r2[m] * k30[m];
k30 += 64;
output3_tm[m] += r3[m] * k30[m];
k30 += 64;

k30 -= 64 * 4;
}

r0 += 64;
r1 += 64;
r2 += 64;
r3 += 64;
output0_tm += 64;
output1_tm += 64;
output2_tm += 64;
output3_tm += 64;
#endif // __ARM_NEON
}
}

for (; q<inch; q++)
{
const float* r0 = bottom_blob_tm.channel(q);

const float* k0 = kernel0_tm.row(q);
const float* k1 = kernel1_tm.row(q);
const float* k2 = kernel2_tm.row(q);
const float* k3 = kernel3_tm.row(q);

float* output0_tm = out0_tm;
float* output1_tm = out1_tm;
float* output2_tm = out2_tm;
float* output3_tm = out3_tm;

// tile
for (int i=0; i<h_tm/8 * w_tm/8; i++)
{
// TODO neon optimize
for (int m=0; m<64; m++)
{
output0_tm[m] += r0[m] * k0[m];
output1_tm[m] += r0[m] * k1[m];
output2_tm[m] += r0[m] * k2[m];
output3_tm[m] += r0[m] * k3[m];
}

r0 += 64;
output0_tm += 64;
output1_tm += 64;
output2_tm += 64;
output3_tm += 64;
}

}
}

#pragma omp parallel for
for (int p=remain_outch_start; p<outch; p++)
{
Mat out0_tm = top_blob_tm.channel(p);
const Mat kernel0_tm = kernel_tm.channel(p);
@@ -1421,6 +1942,7 @@ static void conv3x3s1_winograd64_neon2(const Mat& bottom_blob, Mat& top_blob, co
}

}
bottom_blob_bordered = Mat();
// END transform input

// BEGIN dot


Loading…
Cancel
Save