Browse Source

unroll outch 2 for conv3x3s1 pack1to4

tags/20191113
nihui 6 years ago
parent
commit
2f8b31c3b4
1 changed files with 352 additions and 1 deletions
  1. +352
    -1
      src/layer/arm/convolution_3x3_pack1to4.h

+ 352
- 1
src/layer/arm/convolution_3x3_pack1to4.h View File

@@ -21,8 +21,359 @@ static void conv3x3s1_pack1to4_neon(const Mat& bottom_blob, Mat& top_blob, const

const float* bias = _bias;

int nn_outch = 0;
int remain_outch_start = 0;

#if __ARM_NEON && __aarch64__
nn_outch = outch >> 1;
remain_outch_start = nn_outch << 1;

#pragma omp parallel for num_threads(opt.num_threads)
for (int p=0; p<outch; p++)
for (int pp=0; pp<nn_outch; pp++)
{
int p = pp * 2;

Mat out0 = top_blob.channel(p);
Mat out1 = top_blob.channel(p+1);

float32x4_t _bias0 = bias ? vld1q_f32((const float*)bias + p * 4) : vdupq_n_f32(0.f);
float32x4_t _bias1 = bias ? vld1q_f32((const float*)bias + (p+1) * 4) : vdupq_n_f32(0.f);
out0.fill(_bias0);
out1.fill(_bias1);

const float* k0 = kernel.channel(p);
const float* k1 = kernel.channel(p+1);

for (int q=0; q<inch; q++)
{
float* outptr0 = out0;
float* outptr1 = out1;

const Mat img0 = bottom_blob.channel(q);

const float* r0 = img0.row(0);
const float* r1 = img0.row(1);
const float* r2 = img0.row(2);

float32x4_t _k00_0 = vld1q_f32(k0);
float32x4_t _k01_0 = vld1q_f32(k0+4);
float32x4_t _k02_0 = vld1q_f32(k0+8);
float32x4_t _k10_0 = vld1q_f32(k0+12);
float32x4_t _k11_0 = vld1q_f32(k0+16);
float32x4_t _k12_0 = vld1q_f32(k0+20);
float32x4_t _k20_0 = vld1q_f32(k0+24);
float32x4_t _k21_0 = vld1q_f32(k0+28);
float32x4_t _k22_0 = vld1q_f32(k0+32);

float32x4_t _k00_1 = vld1q_f32(k1);
float32x4_t _k01_1 = vld1q_f32(k1+4);
float32x4_t _k02_1 = vld1q_f32(k1+8);
float32x4_t _k10_1 = vld1q_f32(k1+12);
float32x4_t _k11_1 = vld1q_f32(k1+16);
float32x4_t _k12_1 = vld1q_f32(k1+20);
float32x4_t _k20_1 = vld1q_f32(k1+24);
float32x4_t _k21_1 = vld1q_f32(k1+28);
float32x4_t _k22_1 = vld1q_f32(k1+32);

int i = 0;

for (; i < outh; i++)
{
int j = 0;

for (; j+3<outw; j+=4)
{
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0] \n"

"prfm pldl1keep, [%1, #512] \n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%1] \n"

"prfm pldl1keep, [%2, #128] \n"
"ld1 {v0.4s}, [%2], #16 \n"

"ld1 {v1.2s}, [%2] \n"

"fmla v24.4s, %10.4s, v0.s[0] \n"
"fmla v25.4s, %10.4s, v0.s[1] \n"
"fmla v26.4s, %10.4s, v0.s[2] \n"
"fmla v27.4s, %10.4s, v0.s[3] \n"
"fmla v28.4s, %19.4s, v0.s[0] \n"
"fmla v29.4s, %19.4s, v0.s[1] \n"
"fmla v30.4s, %19.4s, v0.s[2] \n"
"fmla v31.4s, %19.4s, v0.s[3] \n"

"fmla v24.4s, %11.4s, v0.s[1] \n"
"fmla v25.4s, %11.4s, v0.s[2] \n"
"fmla v26.4s, %11.4s, v0.s[3] \n"
"fmla v27.4s, %11.4s, v1.s[0] \n"
"fmla v28.4s, %20.4s, v0.s[1] \n"
"fmla v29.4s, %20.4s, v0.s[2] \n"
"fmla v30.4s, %20.4s, v0.s[3] \n"
"fmla v31.4s, %20.4s, v1.s[0] \n"

"prfm pldl1keep, [%3, #128] \n"
"ld1 {v2.4s}, [%3], #16 \n"

"ld1 {v3.2s}, [%3] \n"

"fmla v24.4s, %12.4s, v0.s[2] \n"
"fmla v25.4s, %12.4s, v0.s[3] \n"
"fmla v26.4s, %12.4s, v1.s[0] \n"
"fmla v27.4s, %12.4s, v1.s[1] \n"
"fmla v28.4s, %21.4s, v0.s[2] \n"
"fmla v29.4s, %21.4s, v0.s[3] \n"
"fmla v30.4s, %21.4s, v1.s[0] \n"
"fmla v31.4s, %21.4s, v1.s[1] \n"

"fmla v24.4s, %13.4s, v2.s[0] \n"
"fmla v25.4s, %13.4s, v2.s[1] \n"
"fmla v26.4s, %13.4s, v2.s[2] \n"
"fmla v27.4s, %13.4s, v2.s[3] \n"
"fmla v28.4s, %22.4s, v2.s[0] \n"
"fmla v29.4s, %22.4s, v2.s[1] \n"
"fmla v30.4s, %22.4s, v2.s[2] \n"
"fmla v31.4s, %22.4s, v2.s[3] \n"

"fmla v24.4s, %14.4s, v2.s[1] \n"
"fmla v25.4s, %14.4s, v2.s[2] \n"
"fmla v26.4s, %14.4s, v2.s[3] \n"
"fmla v27.4s, %14.4s, v3.s[0] \n"
"fmla v28.4s, %23.4s, v2.s[1] \n"
"fmla v29.4s, %23.4s, v2.s[2] \n"
"fmla v30.4s, %23.4s, v2.s[3] \n"
"fmla v31.4s, %23.4s, v3.s[0] \n"

"prfm pldl1keep, [%4, #128] \n"
"ld1 {v0.4s}, [%4], #16 \n"

"ld1 {v1.2s}, [%4] \n"

"fmla v24.4s, %15.4s, v2.s[2] \n"
"fmla v25.4s, %15.4s, v2.s[3] \n"
"fmla v26.4s, %15.4s, v3.s[0] \n"
"fmla v27.4s, %15.4s, v3.s[1] \n"
"fmla v28.4s, %24.4s, v2.s[2] \n"
"fmla v29.4s, %24.4s, v2.s[3] \n"
"fmla v30.4s, %24.4s, v3.s[0] \n"
"fmla v31.4s, %24.4s, v3.s[1] \n"

"fmla v24.4s, %16.4s, v0.s[0] \n"
"fmla v25.4s, %16.4s, v0.s[1] \n"
"fmla v26.4s, %16.4s, v0.s[2] \n"
"fmla v27.4s, %16.4s, v0.s[3] \n"
"fmla v28.4s, %25.4s, v0.s[0] \n"
"fmla v29.4s, %25.4s, v0.s[1] \n"
"fmla v30.4s, %25.4s, v0.s[2] \n"
"fmla v31.4s, %25.4s, v0.s[3] \n"

"fmla v24.4s, %17.4s, v0.s[1] \n"
"fmla v25.4s, %17.4s, v0.s[2] \n"
"fmla v26.4s, %17.4s, v0.s[3] \n"
"fmla v27.4s, %17.4s, v1.s[0] \n"
"fmla v28.4s, %26.4s, v0.s[1] \n"
"fmla v29.4s, %26.4s, v0.s[2] \n"
"fmla v30.4s, %26.4s, v0.s[3] \n"
"fmla v31.4s, %26.4s, v1.s[0] \n"

"fmla v24.4s, %18.4s, v0.s[2] \n"
"fmla v25.4s, %18.4s, v0.s[3] \n"
"fmla v26.4s, %18.4s, v1.s[0] \n"
"fmla v27.4s, %18.4s, v1.s[1] \n"
"fmla v28.4s, %27.4s, v0.s[2] \n"
"fmla v29.4s, %27.4s, v0.s[3] \n"
"fmla v30.4s, %27.4s, v1.s[0] \n"
"fmla v31.4s, %27.4s, v1.s[1] \n"

"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%1], #64 \n"

: "=r"(outptr0), // %0
"=r"(outptr1), // %1
"=r"(r0), // %2
"=r"(r1), // %3
"=r"(r2) // %4
: "0"(outptr0),
"1"(outptr1),
"2"(r0),
"3"(r1),
"4"(r2),
"w"(_k00_0), // %10
"w"(_k01_0), // %11
"w"(_k02_0), // %12
"w"(_k10_0), // %13
"w"(_k11_0), // %14
"w"(_k12_0), // %15
"w"(_k20_0), // %16
"w"(_k21_0), // %17
"w"(_k22_0), // %18
"w"(_k00_1), // %19
"w"(_k01_1), // %20
"w"(_k02_1), // %21
"w"(_k10_1), // %22
"w"(_k11_1), // %23
"w"(_k12_1), // %24
"w"(_k20_1), // %25
"w"(_k21_1), // %26
"w"(_k22_1) // %27
: "memory", "v0", "v1", "v2", "v3", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
}
for (; j+1<outw; j+=2)
{
asm volatile(
"prfm pldl1keep, [%0, #256] \n"
"ld1 {v24.4s, v25.4s}, [%0] \n"

"prfm pldl1keep, [%1, #256] \n"
"ld1 {v26.4s, v27.4s}, [%1] \n"

"prfm pldl1keep, [%2, #128] \n"
"ld1 {v0.4s}, [%2] \n"
"add %2, %2, #8 \n"

"fmla v24.4s, %10.4s, v0.s[0] \n"
"fmla v25.4s, %10.4s, v0.s[1] \n"
"fmla v26.4s, %19.4s, v0.s[0] \n"
"fmla v27.4s, %19.4s, v0.s[1] \n"

"fmla v24.4s, %11.4s, v0.s[1] \n"
"fmla v25.4s, %11.4s, v0.s[2] \n"
"fmla v26.4s, %20.4s, v0.s[1] \n"
"fmla v27.4s, %20.4s, v0.s[2] \n"

"prfm pldl1keep, [%3, #128] \n"
"ld1 {v1.4s}, [%3] \n"

"fmla v24.4s, %12.4s, v0.s[2] \n"
"fmla v25.4s, %12.4s, v0.s[3] \n"
"fmla v26.4s, %21.4s, v0.s[2] \n"
"fmla v27.4s, %21.4s, v0.s[3] \n"

"add %3, %3, #8 \n"

"fmla v24.4s, %13.4s, v1.s[0] \n"
"fmla v25.4s, %13.4s, v1.s[1] \n"
"fmla v26.4s, %22.4s, v1.s[0] \n"
"fmla v27.4s, %22.4s, v1.s[1] \n"

"fmla v24.4s, %14.4s, v1.s[1] \n"
"fmla v25.4s, %14.4s, v1.s[2] \n"
"fmla v26.4s, %23.4s, v1.s[1] \n"
"fmla v27.4s, %23.4s, v1.s[2] \n"

"prfm pldl1keep, [%4, #128] \n"
"ld1 {v0.4s}, [%4] \n"

"fmla v24.4s, %15.4s, v1.s[2] \n"
"fmla v25.4s, %15.4s, v1.s[3] \n"
"fmla v26.4s, %24.4s, v1.s[2] \n"
"fmla v27.4s, %24.4s, v1.s[3] \n"

"add %4, %4, #8 \n"

"fmla v24.4s, %16.4s, v0.s[0] \n"
"fmla v25.4s, %16.4s, v0.s[1] \n"
"fmla v26.4s, %25.4s, v0.s[0] \n"
"fmla v27.4s, %25.4s, v0.s[1] \n"

"fmla v24.4s, %17.4s, v0.s[1] \n"
"fmla v25.4s, %17.4s, v0.s[2] \n"
"fmla v26.4s, %26.4s, v0.s[1] \n"
"fmla v27.4s, %26.4s, v0.s[2] \n"

"fmla v24.4s, %18.4s, v0.s[2] \n"
"fmla v25.4s, %18.4s, v0.s[3] \n"
"fmla v26.4s, %27.4s, v0.s[2] \n"
"fmla v27.4s, %27.4s, v0.s[3] \n"

"st1 {v24.4s, v25.4s}, [%0], #32 \n"
"st1 {v26.4s, v27.4s}, [%1], #32 \n"

: "=r"(outptr0), // %0
"=r"(outptr1), // %1
"=r"(r0), // %2
"=r"(r1), // %3
"=r"(r2) // %4
: "0"(outptr0),
"1"(outptr1),
"2"(r0),
"3"(r1),
"4"(r2),
"w"(_k00_0), // %10
"w"(_k01_0), // %11
"w"(_k02_0), // %12
"w"(_k10_0), // %13
"w"(_k11_0), // %14
"w"(_k12_0), // %15
"w"(_k20_0), // %16
"w"(_k21_0), // %17
"w"(_k22_0), // %18
"w"(_k00_1), // %19
"w"(_k01_1), // %20
"w"(_k02_1), // %21
"w"(_k10_1), // %22
"w"(_k11_1), // %23
"w"(_k12_1), // %24
"w"(_k20_1), // %25
"w"(_k21_1), // %26
"w"(_k22_1) // %27
: "memory", "v0", "v1", "v24", "v25", "v26", "v27"
);
}
for (; j<outw; j++)
{
float32x4_t _sum00 = vld1q_f32(outptr0);
float32x4_t _sum10 = vld1q_f32(outptr1);

float32x4_t _r0 = vld1q_f32(r0);
float32x4_t _r1 = vld1q_f32(r1);
float32x4_t _r2 = vld1q_f32(r2);

_sum00 = vfmaq_laneq_f32(_sum00, _k00_0, _r0, 0);
_sum00 = vfmaq_laneq_f32(_sum00, _k01_0, _r0, 1);
_sum00 = vfmaq_laneq_f32(_sum00, _k02_0, _r0, 2);
_sum00 = vfmaq_laneq_f32(_sum00, _k10_0, _r1, 0);
_sum00 = vfmaq_laneq_f32(_sum00, _k11_0, _r1, 1);
_sum00 = vfmaq_laneq_f32(_sum00, _k12_0, _r1, 2);
_sum00 = vfmaq_laneq_f32(_sum00, _k20_0, _r2, 0);
_sum00 = vfmaq_laneq_f32(_sum00, _k21_0, _r2, 1);
_sum00 = vfmaq_laneq_f32(_sum00, _k22_0, _r2, 2);

_sum10 = vfmaq_laneq_f32(_sum10, _k00_1, _r0, 0);
_sum10 = vfmaq_laneq_f32(_sum10, _k01_1, _r0, 1);
_sum10 = vfmaq_laneq_f32(_sum10, _k02_1, _r0, 2);
_sum10 = vfmaq_laneq_f32(_sum10, _k10_1, _r1, 0);
_sum10 = vfmaq_laneq_f32(_sum10, _k11_1, _r1, 1);
_sum10 = vfmaq_laneq_f32(_sum10, _k12_1, _r1, 2);
_sum10 = vfmaq_laneq_f32(_sum10, _k20_1, _r2, 0);
_sum10 = vfmaq_laneq_f32(_sum10, _k21_1, _r2, 1);
_sum10 = vfmaq_laneq_f32(_sum10, _k22_1, _r2, 2);

vst1q_f32(outptr0, _sum00);
vst1q_f32(outptr1, _sum10);

r0 += 1;
r1 += 1;
r2 += 1;
outptr0 += 4;
outptr1 += 4;
}

r0 += 2;
r1 += 2;
r2 += 2;
}

k0 += 9*4;
k1 += 9*4;
}
}
#endif // __ARM_NEON && __aarch64__

#pragma omp parallel for num_threads(opt.num_threads)
for (int p=remain_outch_start; p<outch; p++)
{
Mat out0 = top_blob.channel(p);



Loading…
Cancel
Save