Browse Source

arm neon optimize for winograd input output transform, about 4%~22% faster

tags/20180427
nihui 8 years ago
parent
commit
d0bcda70ca
1 changed files with 414 additions and 39 deletions
  1. +414
    -39
      src/layer/arm/convolution_3x3.h

+ 414
- 39
src/layer/arm/convolution_3x3.h View File

@@ -5320,6 +5320,15 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
// 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
// 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

#if __ARM_NEON
const float coeff[8] = {
0.25f, 0.5f, -1.25f, 2.f,
-2.5f, 4.f, 4.25f, 5.25f
};
float32x4_t _coeff0 = vld1q_f32(coeff);
float32x4_t _coeff1 = vld1q_f32(coeff+4);
#endif // __ARM_NEON

#pragma omp parallel for
for (int q = 0; q<inch; q++)
{
@@ -5333,23 +5342,230 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
{
for (int j=0; j<w_tm/8; j++)
{
#if __ARM_NEON
const float* r0 = img0.row(i * 6) + j * 6;
const float* r1 = r0 + w;
const float* r2 = r0 + w*2;
const float* r3 = r0 + w*3;

for (int m=0; m+3<8; m+=4)
{
float32x4_t _r0_0123 = vld1q_f32(r0);
float32x4_t _r0_4567 = vld1q_f32(r0+4);
float32x4_t _r1_0123 = vld1q_f32(r1);
float32x4_t _r1_4567 = vld1q_f32(r1+4);
float32x4_t _r2_0123 = vld1q_f32(r2);
float32x4_t _r2_4567 = vld1q_f32(r2+4);
float32x4_t _r3_0123 = vld1q_f32(r3);
float32x4_t _r3_4567 = vld1q_f32(r3+4);

float32x4x2_t _r01_00221133 = vtrnq_f32(_r0_0123, _r1_0123);
float32x4x2_t _r01_44665577 = vtrnq_f32(_r0_4567, _r1_4567);
float32x4x2_t _r23_00221133 = vtrnq_f32(_r2_0123, _r3_0123);
float32x4x2_t _r23_44665577 = vtrnq_f32(_r2_4567, _r3_4567);

// no vswp intrinsic :(
float32x4_t _r_00 = vcombine_f32(vget_low_f32(_r01_00221133.val[0]), vget_low_f32(_r23_00221133.val[0]));
float32x4_t _r_11 = vcombine_f32(vget_low_f32(_r01_00221133.val[1]), vget_low_f32(_r23_00221133.val[1]));
float32x4_t _r_22 = vcombine_f32(vget_high_f32(_r01_00221133.val[0]), vget_high_f32(_r23_00221133.val[0]));
float32x4_t _r_33 = vcombine_f32(vget_high_f32(_r01_00221133.val[1]), vget_high_f32(_r23_00221133.val[1]));
float32x4_t _r_44 = vcombine_f32(vget_low_f32(_r01_44665577.val[0]), vget_low_f32(_r23_44665577.val[0]));
float32x4_t _r_55 = vcombine_f32(vget_low_f32(_r01_44665577.val[1]), vget_low_f32(_r23_44665577.val[1]));
float32x4_t _r_66 = vcombine_f32(vget_high_f32(_r01_44665577.val[0]), vget_high_f32(_r23_44665577.val[0]));
float32x4_t _r_77 = vcombine_f32(vget_high_f32(_r01_44665577.val[1]), vget_high_f32(_r23_44665577.val[1]));

float32x4_t _r_0_m_6 = vsubq_f32(_r_00, _r_66);
float32x4_t _r_7_m_1 = vsubq_f32(_r_77, _r_11);

float32x4_t _r_4_m_2 = vsubq_f32(_r_44, _r_22);
float32x4_t _r_3_m_5 = vsubq_f32(_r_33, _r_55);

float32x4_t _tmp0 = vmlaq_lane_f32(_r_0_m_6, _r_4_m_2, vget_high_f32(_coeff1), 1);
float32x4_t _tmp7 = vmlaq_lane_f32(_r_7_m_1, _r_3_m_5, vget_high_f32(_coeff1), 1);

vst1q_f32(&tmp[0][m], _tmp0);
vst1q_f32(&tmp[7][m], _tmp7);

float32x4_t _r_2_a_6 = vaddq_f32(_r_22, _r_66);
float32x4_t _r_1_a_5 = vaddq_f32(_r_11, _r_55);

float32x4_t _tmp12a = vmlsq_lane_f32(_r_2_a_6, _r_44, vget_high_f32(_coeff1), 0);
float32x4_t _tmp12b = vmlsq_lane_f32(_r_1_a_5, _r_33, vget_high_f32(_coeff1), 0);

float32x4_t _tmp1 = vaddq_f32(_tmp12a, _tmp12b);
float32x4_t _tmp2 = vsubq_f32(_tmp12a, _tmp12b);

vst1q_f32(&tmp[1][m], _tmp1);
vst1q_f32(&tmp[2][m], _tmp2);

float32x4_t _r_4_x_c = vmulq_lane_f32(_r_44, vget_high_f32(_coeff0), 0);
float32x4_t _r_3_x_c = vmulq_lane_f32(_r_33, vget_low_f32(_coeff1), 0);

float32x4_t _tmp34a = vaddq_f32(_r_66, _r_4_x_c);
_tmp34a = vmlaq_lane_f32(_tmp34a, _r_22, vget_low_f32(_coeff0), 0);

float32x4_t _tmp34b = vmlaq_lane_f32(_r_3_x_c, _r_11, vget_low_f32(_coeff0), 1);
_tmp34b = vmlaq_lane_f32(_tmp34b, _r_55, vget_high_f32(_coeff0), 1);

float32x4_t _tmp3 = vaddq_f32(_tmp34a, _tmp34b);
float32x4_t _tmp4 = vsubq_f32(_tmp34a, _tmp34b);

vst1q_f32(&tmp[3][m], _tmp3);
vst1q_f32(&tmp[4][m], _tmp4);

// reuse r04 * 1.25
// reuse r03 * 2.5
float32x4_t _r_2_a_4c = vaddq_f32(_r_22, _r_4_x_c);
float32x4_t _tmp56a = vmlaq_lane_f32(_r_66, _r_2_a_4c, vget_low_f32(_coeff1), 1);
float32x4_t _tmp56b = vmlaq_lane_f32(_r_3_x_c, _r_11, vget_high_f32(_coeff0), 1);
_tmp56b = vmlaq_lane_f32(_tmp56b, _r_55, vget_low_f32(_coeff0), 1);

float32x4_t _tmp5 = vaddq_f32(_tmp56a, _tmp56b);
float32x4_t _tmp6 = vsubq_f32(_tmp56a, _tmp56b);

vst1q_f32(&tmp[5][m], _tmp5);
vst1q_f32(&tmp[6][m], _tmp6);

r0 += w*4;
r1 += w*4;
r2 += w*4;
r3 += w*4;
}

const float* t0 = tmp[0];
const float* t1 = tmp[1];
const float* t2 = tmp[2];
const float* t3 = tmp[3];

float* r0_tm0_0 = img0_tm.row(i * w_tm/8 + j);
float* r0_tm0_4 = img0_tm.row(i * w_tm/8 + j + tiles);
float* r0_tm1_0 = img0_tm.row(i * w_tm/8 + j + tiles * 2);
float* r0_tm1_4 = img0_tm.row(i * w_tm/8 + j + tiles * 3);
float* r0_tm2_0 = img0_tm.row(i * w_tm/8 + j + tiles * 4);
float* r0_tm2_4 = img0_tm.row(i * w_tm/8 + j + tiles * 5);
float* r0_tm3_0 = img0_tm.row(i * w_tm/8 + j + tiles * 6);
float* r0_tm3_4 = img0_tm.row(i * w_tm/8 + j + tiles * 7);
float* r0_tm4_0 = img0_tm.row(i * w_tm/8 + j + tiles * 8);
float* r0_tm4_4 = img0_tm.row(i * w_tm/8 + j + tiles * 9);
float* r0_tm5_0 = img0_tm.row(i * w_tm/8 + j + tiles * 10);
float* r0_tm5_4 = img0_tm.row(i * w_tm/8 + j + tiles * 11);
float* r0_tm6_0 = img0_tm.row(i * w_tm/8 + j + tiles * 12);
float* r0_tm6_4 = img0_tm.row(i * w_tm/8 + j + tiles * 13);
float* r0_tm7_0 = img0_tm.row(i * w_tm/8 + j + tiles * 14);
float* r0_tm7_4 = img0_tm.row(i * w_tm/8 + j + tiles * 15);
float* r0_tm1_0 = img0_tm.row(i * w_tm/8 + j + tiles*2);
float* r0_tm1_4 = img0_tm.row(i * w_tm/8 + j + tiles*3);
float* r0_tm2_0 = img0_tm.row(i * w_tm/8 + j + tiles*4);
float* r0_tm2_4 = img0_tm.row(i * w_tm/8 + j + tiles*5);
float* r0_tm3_0 = img0_tm.row(i * w_tm/8 + j + tiles*6);
float* r0_tm3_4 = img0_tm.row(i * w_tm/8 + j + tiles*7);

for (int m=0; m+3<8; m+=4)
{
float32x4_t _t0_0123 = vld1q_f32(t0);
float32x4_t _t0_4567 = vld1q_f32(t0+4);
float32x4_t _t1_0123 = vld1q_f32(t1);
float32x4_t _t1_4567 = vld1q_f32(t1+4);
float32x4_t _t2_0123 = vld1q_f32(t2);
float32x4_t _t2_4567 = vld1q_f32(t2+4);
float32x4_t _t3_0123 = vld1q_f32(t3);
float32x4_t _t3_4567 = vld1q_f32(t3+4);

float32x4x2_t _t01_00221133 = vtrnq_f32(_t0_0123, _t1_0123);
float32x4x2_t _t01_44665577 = vtrnq_f32(_t0_4567, _t1_4567);
float32x4x2_t _t23_00221133 = vtrnq_f32(_t2_0123, _t3_0123);
float32x4x2_t _t23_44665577 = vtrnq_f32(_t2_4567, _t3_4567);

// no vswp intrinsic :(
float32x4_t _t_00 = vcombine_f32(vget_low_f32(_t01_00221133.val[0]), vget_low_f32(_t23_00221133.val[0]));
float32x4_t _t_11 = vcombine_f32(vget_low_f32(_t01_00221133.val[1]), vget_low_f32(_t23_00221133.val[1]));
float32x4_t _t_22 = vcombine_f32(vget_high_f32(_t01_00221133.val[0]), vget_high_f32(_t23_00221133.val[0]));
float32x4_t _t_33 = vcombine_f32(vget_high_f32(_t01_00221133.val[1]), vget_high_f32(_t23_00221133.val[1]));
float32x4_t _t_44 = vcombine_f32(vget_low_f32(_t01_44665577.val[0]), vget_low_f32(_t23_44665577.val[0]));
float32x4_t _t_55 = vcombine_f32(vget_low_f32(_t01_44665577.val[1]), vget_low_f32(_t23_44665577.val[1]));
float32x4_t _t_66 = vcombine_f32(vget_high_f32(_t01_44665577.val[0]), vget_high_f32(_t23_44665577.val[0]));
float32x4_t _t_77 = vcombine_f32(vget_high_f32(_t01_44665577.val[1]), vget_high_f32(_t23_44665577.val[1]));

float32x4_t _t_0_m_6 = vsubq_f32(_t_00, _t_66);
float32x4_t _t_7_m_1 = vsubq_f32(_t_77, _t_11);

float32x4_t _t_4_m_2 = vsubq_f32(_t_44, _t_22);
float32x4_t _t_3_m_5 = vsubq_f32(_t_33, _t_55);

float32x4_t _r0_tm_0_0 = vmlaq_lane_f32(_t_0_m_6, _t_4_m_2, vget_high_f32(_coeff1), 1);
float32x4_t _r0_tm_4_3 = vmlaq_lane_f32(_t_7_m_1, _t_3_m_5, vget_high_f32(_coeff1), 1);

r0_tm0_0[0] = vgetq_lane_f32(_r0_tm_0_0, 0);
r0_tm1_0[0] = vgetq_lane_f32(_r0_tm_0_0, 1);
r0_tm2_0[0] = vgetq_lane_f32(_r0_tm_0_0, 2);
r0_tm3_0[0] = vgetq_lane_f32(_r0_tm_0_0, 3);

r0_tm0_4[3] = vgetq_lane_f32(_r0_tm_4_3, 0);
r0_tm1_4[3] = vgetq_lane_f32(_r0_tm_4_3, 1);
r0_tm2_4[3] = vgetq_lane_f32(_r0_tm_4_3, 2);
r0_tm3_4[3] = vgetq_lane_f32(_r0_tm_4_3, 3);

float32x4_t _t_2_m_6 = vaddq_f32(_t_22, _t_66);
float32x4_t _t_1_m_5 = vaddq_f32(_t_11, _t_55);

float32x4_t _tmp12a = vmlsq_lane_f32(_t_2_m_6, _t_44, vget_high_f32(_coeff1), 0);
float32x4_t _tmp12b = vmlsq_lane_f32(_t_1_m_5, _t_33, vget_high_f32(_coeff1), 0);

float32x4_t _r0_tm_0_1 = vaddq_f32(_tmp12a, _tmp12b);
float32x4_t _r0_tm_0_2 = vsubq_f32(_tmp12a, _tmp12b);

r0_tm0_0[1] = vgetq_lane_f32(_r0_tm_0_1, 0);
r0_tm1_0[1] = vgetq_lane_f32(_r0_tm_0_1, 1);
r0_tm2_0[1] = vgetq_lane_f32(_r0_tm_0_1, 2);
r0_tm3_0[1] = vgetq_lane_f32(_r0_tm_0_1, 3);

r0_tm0_0[2] = vgetq_lane_f32(_r0_tm_0_2, 0);
r0_tm1_0[2] = vgetq_lane_f32(_r0_tm_0_2, 1);
r0_tm2_0[2] = vgetq_lane_f32(_r0_tm_0_2, 2);
r0_tm3_0[2] = vgetq_lane_f32(_r0_tm_0_2, 3);

float32x4_t _t_4_x_c = vmulq_lane_f32(_t_44, vget_high_f32(_coeff0), 0);
float32x4_t _t_3_x_c = vmulq_lane_f32(_t_33, vget_low_f32(_coeff1), 0);

float32x4_t _tmp34a = vaddq_f32(_t_66, _t_4_x_c);
_tmp34a = vmlaq_lane_f32(_tmp34a, _t_22, vget_low_f32(_coeff0), 0);

float32x4_t _tmp34b = vmlaq_lane_f32(_t_3_x_c, _t_11, vget_low_f32(_coeff0), 1);
_tmp34b = vmlaq_lane_f32(_tmp34b, _t_55, vget_high_f32(_coeff0), 1);

float32x4_t _r0_tm_0_3 = vaddq_f32(_tmp34a, _tmp34b);
float32x4_t _r0_tm_4_0 = vsubq_f32(_tmp34a, _tmp34b);

r0_tm0_0[3] = vgetq_lane_f32(_r0_tm_0_3, 0);
r0_tm1_0[3] = vgetq_lane_f32(_r0_tm_0_3, 1);
r0_tm2_0[3] = vgetq_lane_f32(_r0_tm_0_3, 2);
r0_tm3_0[3] = vgetq_lane_f32(_r0_tm_0_3, 3);

r0_tm0_4[0] = vgetq_lane_f32(_r0_tm_4_0, 0);
r0_tm1_4[0] = vgetq_lane_f32(_r0_tm_4_0, 1);
r0_tm2_4[0] = vgetq_lane_f32(_r0_tm_4_0, 2);
r0_tm3_4[0] = vgetq_lane_f32(_r0_tm_4_0, 3);

float32x4_t _t_2_a_4c = vaddq_f32(_t_22, _t_4_x_c);
float32x4_t _tmp56a = vmlaq_lane_f32(_t_66, _t_2_a_4c, vget_low_f32(_coeff1), 1);
float32x4_t _tmp56b = vmlaq_lane_f32(_t_3_x_c, _t_11, vget_high_f32(_coeff0), 1);
_tmp56b = vmlaq_lane_f32(_tmp56b, _t_55, vget_low_f32(_coeff0), 1);

float32x4_t _r0_tm_4_1 = vaddq_f32(_tmp56a, _tmp56b);
float32x4_t _r0_tm_4_2 = vsubq_f32(_tmp56a, _tmp56b);

r0_tm0_4[1] = vgetq_lane_f32(_r0_tm_4_1, 0);
r0_tm1_4[1] = vgetq_lane_f32(_r0_tm_4_1, 1);
r0_tm2_4[1] = vgetq_lane_f32(_r0_tm_4_1, 2);
r0_tm3_4[1] = vgetq_lane_f32(_r0_tm_4_1, 3);

r0_tm0_4[2] = vgetq_lane_f32(_r0_tm_4_2, 0);
r0_tm1_4[2] = vgetq_lane_f32(_r0_tm_4_2, 1);
r0_tm2_4[2] = vgetq_lane_f32(_r0_tm_4_2, 2);
r0_tm3_4[2] = vgetq_lane_f32(_r0_tm_4_2, 3);

t0 += 8*4;
t1 += 8*4;
t2 += 8*4;
t3 += 8*4;

r0_tm0_0 += img0_tm.w*tiles*2*4;
r0_tm0_4 += img0_tm.w*tiles*2*4;
r0_tm1_0 += img0_tm.w*tiles*2*4;
r0_tm1_4 += img0_tm.w*tiles*2*4;
r0_tm2_0 += img0_tm.w*tiles*2*4;
r0_tm2_4 += img0_tm.w*tiles*2*4;
r0_tm3_0 += img0_tm.w*tiles*2*4;
r0_tm3_4 += img0_tm.w*tiles*2*4;
}
#else
const float* r0 = img0.row(i * 6) + j * 6;

for (int m=0; m<8; m++)
{
@@ -5377,16 +5593,13 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
r0 += w;
}

float* r0_tms_0[8] = { r0_tm0_0, r0_tm1_0, r0_tm2_0, r0_tm3_0, r0_tm4_0, r0_tm5_0, r0_tm6_0, r0_tm7_0 };
float* r0_tms_4[8] = { r0_tm0_4, r0_tm1_4, r0_tm2_4, r0_tm3_4, r0_tm4_4, r0_tm5_4, r0_tm6_4, r0_tm7_4 };
float* r0_tm_0 = img0_tm.row(i * w_tm/8 + j);
float* r0_tm_4 = img0_tm.row(i * w_tm/8 + j + tiles);

for (int m=0; m<8; m++)
{
const float* tmp0 = tmp[m];

float* r0_tm_0 = r0_tms_0[m];
float* r0_tm_4 = r0_tms_4[m];

r0_tm_0[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25f;
r0_tm_4[3] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25f;

@@ -5407,11 +5620,14 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co

r0_tm_4[1] = tmp56a + tmp56b;
r0_tm_4[2] = tmp56a - tmp56b;

r0_tm_0 += img0_tm.w * tiles * 2;
r0_tm_4 += img0_tm.w * tiles * 2;
}
#endif // __ARM_NEON
}
}
}

}
bottom_blob_bordered = Mat();
// END transform input
@@ -6702,6 +6918,11 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
// 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
// 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)

#if __ARM_NEON
const float coeff[4] = { 4.f, 8.f, 16.f, 32.f };
float32x4_t _coeff = vld1q_f32(coeff);
#endif // __ARM_NEON

int w_tm = outw / 6 * 8;
int h_tm = outh / 6 * 8;
const int tiles = w_tm/8 * h_tm/8;
@@ -6713,6 +6934,9 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
Mat out0 = top_blob_bordered.channel(p);

const float bias0 = bias ? bias[p] : 0.f;
#if __ARM_NEON
float32x2_t _bias0 = vdup_n_f32(bias0);
#endif // __ARM_NEON

float tmp[6][8];

@@ -6721,33 +6945,178 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
{
for (int j=0; j<outw/6; j++)
{
#if __ARM_NEON
const float* output0_tm0_0 = out0_tm.row(i * w_tm/8 + j);
const float* output0_tm0_4 = out0_tm.row(i * w_tm/8 + j + tiles);
const float* output0_tm1_0 = out0_tm.row(i * w_tm/8 + j + tiles * 2);
const float* output0_tm1_4 = out0_tm.row(i * w_tm/8 + j + tiles * 3);
const float* output0_tm2_0 = out0_tm.row(i * w_tm/8 + j + tiles * 4);
const float* output0_tm2_4 = out0_tm.row(i * w_tm/8 + j + tiles * 5);
const float* output0_tm3_0 = out0_tm.row(i * w_tm/8 + j + tiles * 6);
const float* output0_tm3_4 = out0_tm.row(i * w_tm/8 + j + tiles * 7);
const float* output0_tm4_0 = out0_tm.row(i * w_tm/8 + j + tiles * 8);
const float* output0_tm4_4 = out0_tm.row(i * w_tm/8 + j + tiles * 9);
const float* output0_tm5_0 = out0_tm.row(i * w_tm/8 + j + tiles * 10);
const float* output0_tm5_4 = out0_tm.row(i * w_tm/8 + j + tiles * 11);
const float* output0_tm6_0 = out0_tm.row(i * w_tm/8 + j + tiles * 12);
const float* output0_tm6_4 = out0_tm.row(i * w_tm/8 + j + tiles * 13);
const float* output0_tm7_0 = out0_tm.row(i * w_tm/8 + j + tiles * 14);
const float* output0_tm7_4 = out0_tm.row(i * w_tm/8 + j + tiles * 15);
const float* output0_tm1_0 = out0_tm.row(i * w_tm/8 + j + tiles*2);
const float* output0_tm1_4 = out0_tm.row(i * w_tm/8 + j + tiles*3);
const float* output0_tm2_0 = out0_tm.row(i * w_tm/8 + j + tiles*4);
const float* output0_tm2_4 = out0_tm.row(i * w_tm/8 + j + tiles*5);
const float* output0_tm3_0 = out0_tm.row(i * w_tm/8 + j + tiles*6);
const float* output0_tm3_4 = out0_tm.row(i * w_tm/8 + j + tiles*7);

for (int m=0; m+3<8; m+=4)
{
float32x4_t _output0_tm0_0123 = vld1q_f32(output0_tm0_0);
float32x4_t _output0_tm0_4567 = vld1q_f32(output0_tm0_4);
float32x4_t _output0_tm1_0123 = vld1q_f32(output0_tm1_0);
float32x4_t _output0_tm1_4567 = vld1q_f32(output0_tm1_4);
float32x4_t _output0_tm2_0123 = vld1q_f32(output0_tm2_0);
float32x4_t _output0_tm2_4567 = vld1q_f32(output0_tm2_4);
float32x4_t _output0_tm3_0123 = vld1q_f32(output0_tm3_0);
float32x4_t _output0_tm3_4567 = vld1q_f32(output0_tm3_4);

float32x4x2_t _output0_tm01_00221133 = vtrnq_f32(_output0_tm0_0123, _output0_tm1_0123);
float32x4x2_t _output0_tm01_44665577 = vtrnq_f32(_output0_tm0_4567, _output0_tm1_4567);
float32x4x2_t _output0_tm23_00221133 = vtrnq_f32(_output0_tm2_0123, _output0_tm3_0123);
float32x4x2_t _output0_tm23_44665577 = vtrnq_f32(_output0_tm2_4567, _output0_tm3_4567);

// no vswp intrinsic :(
float32x4_t _output0_tm_00 = vcombine_f32(vget_low_f32(_output0_tm01_00221133.val[0]), vget_low_f32(_output0_tm23_00221133.val[0]));
float32x4_t _output0_tm_11 = vcombine_f32(vget_low_f32(_output0_tm01_00221133.val[1]), vget_low_f32(_output0_tm23_00221133.val[1]));
float32x4_t _output0_tm_22 = vcombine_f32(vget_high_f32(_output0_tm01_00221133.val[0]), vget_high_f32(_output0_tm23_00221133.val[0]));
float32x4_t _output0_tm_33 = vcombine_f32(vget_high_f32(_output0_tm01_00221133.val[1]), vget_high_f32(_output0_tm23_00221133.val[1]));
float32x4_t _output0_tm_44 = vcombine_f32(vget_low_f32(_output0_tm01_44665577.val[0]), vget_low_f32(_output0_tm23_44665577.val[0]));
float32x4_t _output0_tm_55 = vcombine_f32(vget_low_f32(_output0_tm01_44665577.val[1]), vget_low_f32(_output0_tm23_44665577.val[1]));
float32x4_t _output0_tm_66 = vcombine_f32(vget_high_f32(_output0_tm01_44665577.val[0]), vget_high_f32(_output0_tm23_44665577.val[0]));
float32x4_t _output0_tm_77 = vcombine_f32(vget_high_f32(_output0_tm01_44665577.val[1]), vget_high_f32(_output0_tm23_44665577.val[1]));

float32x4_t _tmp024a = vaddq_f32(_output0_tm_11, _output0_tm_22);
float32x4_t _tmp135a = vsubq_f32(_output0_tm_11, _output0_tm_22);

float32x4_t _tmp024b = vaddq_f32(_output0_tm_33, _output0_tm_44);
float32x4_t _tmp135b = vsubq_f32(_output0_tm_33, _output0_tm_44);

float32x4_t _tmp024c = vaddq_f32(_output0_tm_55, _output0_tm_66);
float32x4_t _tmp135c = vsubq_f32(_output0_tm_55, _output0_tm_66);

float32x4_t _tmp0 = vaddq_f32(_output0_tm_00, _tmp024a);
_tmp0 = vmlaq_lane_f32(_tmp0, _tmp024c, vget_high_f32(_coeff), 1);
_tmp0 = vaddq_f32(_tmp0, _tmp024b);

float32x4_t _tmp2 = vmlaq_lane_f32(_tmp024a, _tmp024b, vget_low_f32(_coeff), 0);
_tmp2 = vmlaq_lane_f32(_tmp2, _tmp024c, vget_low_f32(_coeff), 1);

float32x4_t _tmp4 = vmlaq_lane_f32(_tmp024a, _tmp024b, vget_high_f32(_coeff), 0);
_tmp4 = vaddq_f32(_tmp4, _tmp024c);
_tmp4 = vaddq_f32(_tmp4, _tmp024c);

vst1q_f32(&tmp[0][m], _tmp0);
vst1q_f32(&tmp[2][m], _tmp2);
vst1q_f32(&tmp[4][m], _tmp4);

float32x4_t _tmp1 = vmlaq_lane_f32(_tmp135a, _tmp135c, vget_high_f32(_coeff), 0);
_tmp1 = vaddq_f32(_tmp1, _tmp135b);
_tmp1 = vaddq_f32(_tmp1, _tmp135b);

float32x4_t _tmp3 = vmlaq_lane_f32(_tmp135a, _tmp135b, vget_low_f32(_coeff), 1);
_tmp3 = vmlaq_lane_f32(_tmp3, _tmp135c, vget_low_f32(_coeff), 0);

float32x4_t _tmp5 = vaddq_f32(_output0_tm_77, _tmp135a);
_tmp5 = vmlaq_lane_f32(_tmp5, _tmp135b, vget_high_f32(_coeff), 1);
_tmp5 = vaddq_f32(_tmp5, _tmp135c);

vst1q_f32(&tmp[1][m], _tmp1);
vst1q_f32(&tmp[3][m], _tmp3);
vst1q_f32(&tmp[5][m], _tmp5);

output0_tm0_0 += out0_tm.w * tiles * 2*4;
output0_tm0_4 += out0_tm.w * tiles * 2*4;
output0_tm1_0 += out0_tm.w * tiles * 2*4;
output0_tm1_4 += out0_tm.w * tiles * 2*4;
output0_tm2_0 += out0_tm.w * tiles * 2*4;
output0_tm2_4 += out0_tm.w * tiles * 2*4;
output0_tm3_0 += out0_tm.w * tiles * 2*4;
output0_tm3_4 += out0_tm.w * tiles * 2*4;
}

const float* t0 = tmp[0];
const float* t1 = tmp[1];

float* output0 = out0.row(i * 6) + j * 6;
float* output1 = output0 + outw;

const float* output0_tms_0[8] = { output0_tm0_0, output0_tm1_0, output0_tm2_0, output0_tm3_0, output0_tm4_0, output0_tm5_0, output0_tm6_0, output0_tm7_0 };
const float* output0_tms_4[8] = { output0_tm0_4, output0_tm1_4, output0_tm2_4, output0_tm3_4, output0_tm4_4, output0_tm5_4, output0_tm6_4, output0_tm7_4 };
for (int m=0; m+1<6; m+=2)
{
float32x4_t _t0_0123 = vld1q_f32(t0);
float32x4_t _t0_4567 = vld1q_f32(t0+4);
float32x4_t _t1_0123 = vld1q_f32(t1);
float32x4_t _t1_4567 = vld1q_f32(t1+4);

float32x4x2_t _t01_00221133 = vtrnq_f32(_t0_0123, _t1_0123);
float32x4x2_t _t01_44665577 = vtrnq_f32(_t0_4567, _t1_4567);

float32x2_t _t_00 = vget_low_f32(_t01_00221133.val[0]);
float32x2_t _t_11 = vget_low_f32(_t01_00221133.val[1]);
float32x2_t _t_22 = vget_high_f32(_t01_00221133.val[0]);
float32x2_t _t_33 = vget_high_f32(_t01_00221133.val[1]);
float32x2_t _t_44 = vget_low_f32(_t01_44665577.val[0]);
float32x2_t _t_55 = vget_low_f32(_t01_44665577.val[1]);
float32x2_t _t_66 = vget_high_f32(_t01_44665577.val[0]);
float32x2_t _t_77 = vget_high_f32(_t01_44665577.val[1]);

float32x2_t _tmp024a = vadd_f32(_t_11, _t_22);
float32x2_t _tmp135a = vsub_f32(_t_11, _t_22);

float32x2_t _tmp024b = vadd_f32(_t_33, _t_44);
float32x2_t _tmp135b = vsub_f32(_t_33, _t_44);

float32x2_t _tmp024c = vadd_f32(_t_55, _t_66);
float32x2_t _tmp135c = vsub_f32(_t_55, _t_66);

float32x2_t _output_0 = vadd_f32(_t_00, _tmp024a);
_output_0 = vmla_lane_f32(_output_0, _tmp024c, vget_high_f32(_coeff), 1);
_output_0 = vadd_f32(_output_0, _tmp024b);
_output_0 = vadd_f32(_output_0, _bias0);

float32x2_t _output_2 = vmla_lane_f32(_tmp024a, _tmp024b, vget_low_f32(_coeff), 0);
_output_2 = vmla_lane_f32(_output_2, _tmp024c, vget_low_f32(_coeff), 1);
_output_2 = vadd_f32(_output_2, _bias0);

float32x2_t _output_4 = vmla_lane_f32(_tmp024a, _tmp024b, vget_high_f32(_coeff), 0);
_output_4 = vadd_f32(_output_4, _tmp024c);
_output_4 = vadd_f32(_output_4, _tmp024c);
_output_4 = vadd_f32(_output_4, _bias0);

output0[0] = vget_lane_f32(_output_0, 0);
output1[0] = vget_lane_f32(_output_0, 1);
output0[2] = vget_lane_f32(_output_2, 0);
output1[2] = vget_lane_f32(_output_2, 1);
output0[4] = vget_lane_f32(_output_4, 0);
output1[4] = vget_lane_f32(_output_4, 1);

float32x2_t _output_1 = vmla_lane_f32(_tmp135a, _tmp135c, vget_high_f32(_coeff), 0);
_output_1 = vadd_f32(_output_1, _tmp135b);
_output_1 = vadd_f32(_output_1, _tmp135b);
_output_1 = vadd_f32(_output_1, _bias0);

float32x2_t _output_3 = vmla_lane_f32(_tmp135a, _tmp135b, vget_low_f32(_coeff), 1);
_output_3 = vmla_lane_f32(_output_3, _tmp135c, vget_low_f32(_coeff), 0);
_output_3 = vadd_f32(_output_3, _bias0);

float32x2_t _output_5 = vadd_f32(_t_77, _tmp135a);
_output_5 = vmla_lane_f32(_output_5, _tmp135b, vget_high_f32(_coeff), 1);
_output_5 = vadd_f32(_output_5, _tmp135c);
_output_5 = vadd_f32(_output_5, _bias0);

output0[1] = vget_lane_f32(_output_1, 0);
output1[1] = vget_lane_f32(_output_1, 1);
output0[3] = vget_lane_f32(_output_3, 0);
output1[3] = vget_lane_f32(_output_3, 1);
output0[5] = vget_lane_f32(_output_5, 0);
output1[5] = vget_lane_f32(_output_5, 1);

t0 += 8*2;
t1 += 8*2;
output0 += outw*2;
output1 += outw*2;
}
#else
const float* output0_tm_0 = out0_tm.row(i * w_tm/8 + j);
const float* output0_tm_4 = out0_tm.row(i * w_tm/8 + j + tiles);

for (int m=0; m<8; m++)
{
const float* output0_tm_0 = output0_tms_0[m];
const float* output0_tm_4 = output0_tms_4[m];

float tmp024a = output0_tm_0[1] + output0_tm_0[2];
float tmp135a = output0_tm_0[1] - output0_tm_0[2];

@@ -6764,8 +7133,13 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co
tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
tmp[5][m] = output0_tm_4[3] + tmp135a + tmp135b * 32 + tmp135c;

output0_tm_0 += out0_tm.w * tiles * 2;
output0_tm_4 += out0_tm.w * tiles * 2;
}

float* output0 = out0.row(i * 6) + j * 6;

for (int m=0; m<6; m++)
{
const float* tmp0 = tmp[m];
@@ -6789,6 +7163,7 @@ static void conv3x3s1_winograd64_neon4(const Mat& bottom_blob, Mat& top_blob, co

output0 += outw;
}
#endif // __ARM_NEON
}
}
}


Loading…
Cancel
Save