| @@ -683,6 +683,7 @@ static void conv3x3s1_winograd64_neon(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| bottom_blob_bordered = Mat(); | |||
| // END transform input | |||
| // BEGIN dot | |||
| @@ -692,8 +693,528 @@ static void conv3x3s1_winograd64_neon(const Mat& bottom_blob, Mat& top_blob, con | |||
| int h_tm = outh / 6 * 8; | |||
| top_blob_tm.create(8*8, w_tm/8 * h_tm/8, outch); | |||
| int nn_outch = outch >> 2; | |||
| int remain_outch_start = nn_outch << 2; | |||
| #pragma omp parallel for | |||
| for (int p = 0; p<outch; p++) | |||
| for (int pp=0; pp<nn_outch; pp++) | |||
| { | |||
| int p = pp * 4; | |||
| Mat out0_tm = top_blob_tm.channel(p); | |||
| Mat out1_tm = top_blob_tm.channel(p+1); | |||
| Mat out2_tm = top_blob_tm.channel(p+2); | |||
| Mat out3_tm = top_blob_tm.channel(p+3); | |||
| const Mat kernel0_tm = kernel_tm.channel(p); | |||
| const Mat kernel1_tm = kernel_tm.channel(p+1); | |||
| const Mat kernel2_tm = kernel_tm.channel(p+2); | |||
| const Mat kernel3_tm = kernel_tm.channel(p+3); | |||
| out0_tm.fill(0.f); | |||
| out1_tm.fill(0.f); | |||
| out2_tm.fill(0.f); | |||
| out3_tm.fill(0.f); | |||
| int q = 0; | |||
| for (; q+3<inch; q+=4) | |||
| { | |||
| const float* r0 = bottom_blob_tm.channel(q); | |||
| const float* r1 = bottom_blob_tm.channel(q+1); | |||
| const float* r2 = bottom_blob_tm.channel(q+2); | |||
| const float* r3 = bottom_blob_tm.channel(q+3); | |||
| const float* k00 = kernel0_tm.row(q); | |||
| const float* k10 = kernel1_tm.row(q); | |||
| const float* k20 = kernel2_tm.row(q); | |||
| const float* k30 = kernel3_tm.row(q); | |||
| float* output0_tm = out0_tm; | |||
| float* output1_tm = out1_tm; | |||
| float* output2_tm = out2_tm; | |||
| float* output3_tm = out3_tm; | |||
| // tile | |||
| for (int i=0; i<h_tm/8 * w_tm/8; i++) | |||
| { | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (int m=0; m+7<64; m+=8) | |||
| { | |||
| float32x4_t _output0_tm = vld1q_f32(output0_tm); | |||
| float32x4_t _output1_tm = vld1q_f32(output1_tm); | |||
| float32x4_t _output2_tm = vld1q_f32(output2_tm); | |||
| float32x4_t _output3_tm = vld1q_f32(output3_tm); | |||
| float32x4_t _r0 = vld1q_f32(r0); | |||
| float32x4_t _r1 = vld1q_f32(r1); | |||
| float32x4_t _r2 = vld1q_f32(r2); | |||
| float32x4_t _r3 = vld1q_f32(r3); | |||
| float32x4_t _k00 = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k01 = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k02 = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k03 = vld1q_f32(k00); | |||
| k00 += 64; | |||
| k00 -= 64*4; | |||
| _output0_tm = vmlaq_f32(_output0_tm, _r0, _k00); | |||
| _output0_tm = vmlaq_f32(_output0_tm, _r1, _k01); | |||
| _output0_tm = vmlaq_f32(_output0_tm, _r2, _k02); | |||
| _output0_tm = vmlaq_f32(_output0_tm, _r3, _k03); | |||
| float32x4_t _k10 = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k11 = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k12 = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k13 = vld1q_f32(k10); | |||
| k10 += 64; | |||
| k10 -= 64*4; | |||
| _output1_tm = vmlaq_f32(_output1_tm, _r0, _k10); | |||
| _output1_tm = vmlaq_f32(_output1_tm, _r1, _k11); | |||
| _output1_tm = vmlaq_f32(_output1_tm, _r2, _k12); | |||
| _output1_tm = vmlaq_f32(_output1_tm, _r3, _k13); | |||
| float32x4_t _k20 = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k21 = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k22 = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k23 = vld1q_f32(k20); | |||
| k20 += 64; | |||
| k20 -= 64*4; | |||
| _output2_tm = vmlaq_f32(_output2_tm, _r0, _k20); | |||
| _output2_tm = vmlaq_f32(_output2_tm, _r1, _k21); | |||
| _output2_tm = vmlaq_f32(_output2_tm, _r2, _k22); | |||
| _output2_tm = vmlaq_f32(_output2_tm, _r3, _k23); | |||
| float32x4_t _k30 = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k31 = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k32 = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k33 = vld1q_f32(k30); | |||
| k30 += 64; | |||
| k30 -= 64*4; | |||
| _output3_tm = vmlaq_f32(_output3_tm, _r0, _k30); | |||
| _output3_tm = vmlaq_f32(_output3_tm, _r1, _k31); | |||
| _output3_tm = vmlaq_f32(_output3_tm, _r2, _k32); | |||
| _output3_tm = vmlaq_f32(_output3_tm, _r3, _k33); | |||
| vst1q_f32(output0_tm, _output0_tm); | |||
| vst1q_f32(output1_tm, _output1_tm); | |||
| vst1q_f32(output2_tm, _output2_tm); | |||
| vst1q_f32(output3_tm, _output3_tm); | |||
| output0_tm += 4; | |||
| output1_tm += 4; | |||
| output2_tm += 4; | |||
| output3_tm += 4; | |||
| r0 += 4; | |||
| r1 += 4; | |||
| r2 += 4; | |||
| r3 += 4; | |||
| k00 += 4; | |||
| k10 += 4; | |||
| k20 += 4; | |||
| k30 += 4; | |||
| float32x4_t _output0_tmn = vld1q_f32(output0_tm); | |||
| float32x4_t _output1_tmn = vld1q_f32(output1_tm); | |||
| float32x4_t _output2_tmn = vld1q_f32(output2_tm); | |||
| float32x4_t _output3_tmn = vld1q_f32(output3_tm); | |||
| float32x4_t _r0n = vld1q_f32(r0); | |||
| float32x4_t _r1n = vld1q_f32(r1); | |||
| float32x4_t _r2n = vld1q_f32(r2); | |||
| float32x4_t _r3n = vld1q_f32(r3); | |||
| float32x4_t _k00n = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k01n = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k02n = vld1q_f32(k00); | |||
| k00 += 64; | |||
| float32x4_t _k03n = vld1q_f32(k00); | |||
| k00 += 64; | |||
| k00 -= 64*4; | |||
| _output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n); | |||
| _output0_tmn = vmlaq_f32(_output0_tmn, _r1n, _k01n); | |||
| _output0_tmn = vmlaq_f32(_output0_tmn, _r2n, _k02n); | |||
| _output0_tmn = vmlaq_f32(_output0_tmn, _r3n, _k03n); | |||
| float32x4_t _k10n = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k11n = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k12n = vld1q_f32(k10); | |||
| k10 += 64; | |||
| float32x4_t _k13n = vld1q_f32(k10); | |||
| k10 += 64; | |||
| k10 -= 64*4; | |||
| _output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n); | |||
| _output1_tmn = vmlaq_f32(_output1_tmn, _r1n, _k11n); | |||
| _output1_tmn = vmlaq_f32(_output1_tmn, _r2n, _k12n); | |||
| _output1_tmn = vmlaq_f32(_output1_tmn, _r3n, _k13n); | |||
| float32x4_t _k20n = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k21n = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k22n = vld1q_f32(k20); | |||
| k20 += 64; | |||
| float32x4_t _k23n = vld1q_f32(k20); | |||
| k20 += 64; | |||
| k20 -= 64*4; | |||
| _output2_tmn = vmlaq_f32(_output2_tmn, _r0n, _k20n); | |||
| _output2_tmn = vmlaq_f32(_output2_tmn, _r1n, _k21n); | |||
| _output2_tmn = vmlaq_f32(_output2_tmn, _r2n, _k22n); | |||
| _output2_tmn = vmlaq_f32(_output2_tmn, _r3n, _k23n); | |||
| float32x4_t _k30n = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k31n = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k32n = vld1q_f32(k30); | |||
| k30 += 64; | |||
| float32x4_t _k33n = vld1q_f32(k30); | |||
| k30 += 64; | |||
| k30 -= 64*4; | |||
| _output3_tmn = vmlaq_f32(_output3_tmn, _r0n, _k30n); | |||
| _output3_tmn = vmlaq_f32(_output3_tmn, _r1n, _k31n); | |||
| _output3_tmn = vmlaq_f32(_output3_tmn, _r2n, _k32n); | |||
| _output3_tmn = vmlaq_f32(_output3_tmn, _r3n, _k33n); | |||
| vst1q_f32(output0_tm, _output0_tmn); | |||
| vst1q_f32(output1_tm, _output1_tmn); | |||
| vst1q_f32(output2_tm, _output2_tmn); | |||
| vst1q_f32(output3_tm, _output3_tmn); | |||
| output0_tm += 4; | |||
| output1_tm += 4; | |||
| output2_tm += 4; | |||
| output3_tm += 4; | |||
| r0 += 4; | |||
| r1 += 4; | |||
| r2 += 4; | |||
| r3 += 4; | |||
| k00 += 4; | |||
| k10 += 4; | |||
| k20 += 4; | |||
| k30 += 4; | |||
| } | |||
| #else // __aarch64__ | |||
| asm volatile( | |||
| "mov r4, #8 \n" | |||
| "pld [%0, #256] \n" | |||
| "vld1.f32 {d16-d19}, [%0 :128]\n"//q8 q9 = _output0_tm | |||
| "0: \n" | |||
| "pld [%4, #256] \n" | |||
| "vld1.f32 {d0-d3}, [%4 :128]! \n"//q0 q1 = _r0 | |||
| "pld [%8, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k00 | |||
| "add %8, %8, #256 \n" | |||
| "vmla.f32 q8, q0, q10 \n" | |||
| "vmla.f32 q9, q1, q11 \n" | |||
| "pld [%1, #256] \n" | |||
| "vld1.f32 {d24-d27}, [%1 :128]\n"//q12 q13 = _output1_tm | |||
| "pld [%9, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k10 | |||
| "add %9, %9, #256 \n" | |||
| "vmla.f32 q12, q0, q14 \n" | |||
| "vmla.f32 q13, q1, q15 \n" | |||
| "pld [%5, #256] \n" | |||
| "vld1.f32 {d4-d7}, [%5 :128]! \n"//q2 q3 = _r1 | |||
| "pld [%8, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k01 | |||
| "add %8, %8, #256 \n" | |||
| "vmla.f32 q8, q2, q10 \n" | |||
| "vmla.f32 q9, q3, q11 \n" | |||
| "pld [%9, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k11 | |||
| "add %9, %9, #256 \n" | |||
| "vmla.f32 q12, q2, q14 \n" | |||
| "vmla.f32 q13, q3, q15 \n" | |||
| "pld [%6, #256] \n" | |||
| "vld1.f32 {d8-d11}, [%6 :128]!\n"//q4 q5 = _r2 | |||
| "pld [%8, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k02 | |||
| "add %8, %8, #256 \n" | |||
| "vmla.f32 q8, q4, q10 \n" | |||
| "vmla.f32 q9, q5, q11 \n" | |||
| "pld [%9, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k12 | |||
| "add %9, %9, #256 \n" | |||
| "vmla.f32 q12, q4, q14 \n" | |||
| "vmla.f32 q13, q5, q15 \n" | |||
| "pld [%7, #256] \n" | |||
| "vld1.f32 {d12-d15}, [%7 :128]!\n"//q6 q7 = _r3 | |||
| "pld [%8, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%8 :128]\n"//q10 q11 = _k03 | |||
| "sub %8, %8, #736 \n" | |||
| "vmla.f32 q8, q6, q10 \n" | |||
| "vmla.f32 q9, q7, q11 \n" | |||
| "pld [%9, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%9 :128]\n"//q14 q15 = _k13 | |||
| "sub %9, %9, #736 \n" | |||
| "vmla.f32 q12, q6, q14 \n" | |||
| "vmla.f32 q13, q7, q15 \n" | |||
| "vst1.f32 {d16-d19}, [%0 :128]!\n" | |||
| "pld [%2, #256] \n" | |||
| "vld1.f32 {d16-d19}, [%2 :128]\n"//q8 q9 = _output2_tm | |||
| "pld [%10, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k20 | |||
| "add %10, %10, #256 \n" | |||
| "vmla.f32 q8, q0, q10 \n" | |||
| "vmla.f32 q9, q1, q11 \n" | |||
| "vst1.f32 {d24-d27}, [%1 :128]!\n" | |||
| "pld [%3, #256] \n" | |||
| "vld1.f32 {d24-d27}, [%3 :128]\n"//q12 q13 = _output3_tm | |||
| "pld [%11, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k30 | |||
| "add %11, %11, #256 \n" | |||
| "vmla.f32 q12, q0, q14 \n" | |||
| "vmla.f32 q13, q1, q15 \n" | |||
| "pld [%10, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k21 | |||
| "add %10, %10, #256 \n" | |||
| "vmla.f32 q8, q2, q10 \n" | |||
| "vmla.f32 q9, q3, q11 \n" | |||
| "pld [%11, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k31 | |||
| "add %11, %11, #256 \n" | |||
| "vmla.f32 q12, q2, q14 \n" | |||
| "vmla.f32 q13, q3, q15 \n" | |||
| "pld [%10, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k22 | |||
| "add %10, %10, #256 \n" | |||
| "vmla.f32 q8, q4, q10 \n" | |||
| "vmla.f32 q9, q5, q11 \n" | |||
| "pld [%11, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k32 | |||
| "add %11, %11, #256 \n" | |||
| "vmla.f32 q12, q4, q14 \n" | |||
| "vmla.f32 q13, q5, q15 \n" | |||
| "pld [%10, #256] \n" | |||
| "vld1.f32 {d20-d23}, [%10 :128]\n"//q10 q11 = _k23 | |||
| "sub %10, %10, #736 \n" | |||
| "vmla.f32 q8, q6, q10 \n" | |||
| "vmla.f32 q9, q7, q11 \n" | |||
| "pld [%11, #256] \n" | |||
| "vld1.f32 {d28-d31}, [%11 :128]\n"//q14 q15 = _k33 | |||
| "sub %11, %11, #736 \n" | |||
| "vmla.f32 q12, q6, q14 \n" | |||
| "vmla.f32 q13, q7, q15 \n" | |||
| "vst1.f32 {d16-d19}, [%2 :128]!\n" | |||
| "pld [%0, #256] \n" | |||
| "vld1.f32 {d16-d19}, [%0 :128]\n"//q8 q9 = _output0_tm | |||
| "subs r4, r4, #1 \n" | |||
| "vst1.f32 {d24-d27}, [%3 :128]!\n" | |||
| "bne 0b \n" | |||
| : "=r"(output0_tm), // %0 | |||
| "=r"(output1_tm), // %1 | |||
| "=r"(output2_tm), // %2 | |||
| "=r"(output3_tm), // %3 | |||
| "=r"(r0), // %4 | |||
| "=r"(r1), // %5 | |||
| "=r"(r2), // %6 | |||
| "=r"(r3), // %7 | |||
| "=r"(k00), // %8 | |||
| "=r"(k10), // %9 | |||
| "=r"(k20), // %10 | |||
| "=r"(k30) // %11 | |||
| : "0"(output0_tm), | |||
| "1"(output1_tm), | |||
| "2"(output2_tm), | |||
| "3"(output3_tm), | |||
| "4"(r0), | |||
| "5"(r1), | |||
| "6"(r2), | |||
| "7"(r3), | |||
| "8"(k00), | |||
| "9"(k10), | |||
| "10"(k20), | |||
| "11"(k30) | |||
| : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15" | |||
| ); | |||
| #endif // __aarch64__ | |||
| k00 -= 64; | |||
| k10 -= 64; | |||
| k20 -= 64; | |||
| k30 -= 64; | |||
| #else | |||
| for (int m=0; m<64; m++) | |||
| { | |||
| output0_tm[m] += r0[m] * k00[m]; | |||
| k00 += 64; | |||
| output0_tm[m] += r1[m] * k00[m]; | |||
| k00 += 64; | |||
| output0_tm[m] += r2[m] * k00[m]; | |||
| k00 += 64; | |||
| output0_tm[m] += r3[m] * k00[m]; | |||
| k00 += 64; | |||
| k00 -= 64 * 4; | |||
| output1_tm[m] += r0[m] * k10[m]; | |||
| k10 += 64; | |||
| output1_tm[m] += r1[m] * k10[m]; | |||
| k10 += 64; | |||
| output1_tm[m] += r2[m] * k10[m]; | |||
| k10 += 64; | |||
| output1_tm[m] += r3[m] * k10[m]; | |||
| k10 += 64; | |||
| k10 -= 64 * 4; | |||
| output2_tm[m] += r0[m] * k20[m]; | |||
| k20 += 64; | |||
| output2_tm[m] += r1[m] * k20[m]; | |||
| k20 += 64; | |||
| output2_tm[m] += r2[m] * k20[m]; | |||
| k20 += 64; | |||
| output2_tm[m] += r3[m] * k20[m]; | |||
| k20 += 64; | |||
| k20 -= 64 * 4; | |||
| output3_tm[m] += r0[m] * k30[m]; | |||
| k30 += 64; | |||
| output3_tm[m] += r1[m] * k30[m]; | |||
| k30 += 64; | |||
| output3_tm[m] += r2[m] * k30[m]; | |||
| k30 += 64; | |||
| output3_tm[m] += r3[m] * k30[m]; | |||
| k30 += 64; | |||
| k30 -= 64 * 4; | |||
| } | |||
| r0 += 64; | |||
| r1 += 64; | |||
| r2 += 64; | |||
| r3 += 64; | |||
| output0_tm += 64; | |||
| output1_tm += 64; | |||
| output2_tm += 64; | |||
| output3_tm += 64; | |||
| #endif // __ARM_NEON | |||
| } | |||
| } | |||
| for (; q<inch; q++) | |||
| { | |||
| const float* r0 = bottom_blob_tm.channel(q); | |||
| const float* k0 = kernel0_tm.row(q); | |||
| const float* k1 = kernel1_tm.row(q); | |||
| const float* k2 = kernel2_tm.row(q); | |||
| const float* k3 = kernel3_tm.row(q); | |||
| float* output0_tm = out0_tm; | |||
| float* output1_tm = out1_tm; | |||
| float* output2_tm = out2_tm; | |||
| float* output3_tm = out3_tm; | |||
| // tile | |||
| for (int i=0; i<h_tm/8 * w_tm/8; i++) | |||
| { | |||
| // TODO neon optimize | |||
| for (int m=0; m<64; m++) | |||
| { | |||
| output0_tm[m] += r0[m] * k0[m]; | |||
| output1_tm[m] += r0[m] * k1[m]; | |||
| output2_tm[m] += r0[m] * k2[m]; | |||
| output3_tm[m] += r0[m] * k3[m]; | |||
| } | |||
| r0 += 64; | |||
| output0_tm += 64; | |||
| output1_tm += 64; | |||
| output2_tm += 64; | |||
| output3_tm += 64; | |||
| } | |||
| } | |||
| } | |||
| #pragma omp parallel for | |||
| for (int p=remain_outch_start; p<outch; p++) | |||
| { | |||
| Mat out0_tm = top_blob_tm.channel(p); | |||
| const Mat kernel0_tm = kernel_tm.channel(p); | |||
| @@ -1421,6 +1942,7 @@ static void conv3x3s1_winograd64_neon2(const Mat& bottom_blob, Mat& top_blob, co | |||
| } | |||
| } | |||
| bottom_blob_bordered = Mat(); | |||
| // END transform input | |||
| // BEGIN dot | |||