diff --git a/src/layer/x86/convolution_3x3_winograd.h b/src/layer/x86/convolution_3x3_winograd.h index bb068cd3d..c1ee37582 100644 --- a/src/layer/x86/convolution_3x3_winograd.h +++ b/src/layer/x86/convolution_3x3_winograd.h @@ -116,8 +116,9 @@ static void pack_A_tile(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk } } -static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk) +static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) { + #pragma omp parallel for num_threads(nT) for (int b = 0; b < batch; b++) { float* pp = BT.row(b); @@ -126,28 +127,26 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, #if __SSE2__ for (; jj + 11 < max_jj; jj += 12) { - int N = batch; const float* p0 = B; int kk = 0; #if __AVX__ #if __AVX512F__ - N = batch * 16; - p0 += jj * N + b * 16; + p0 += (b * max_jj + jj) * 16; for (; kk + 15 < max_kk; kk += 16) { __m512 _r0 = _mm512_load_ps(p0); - __m512 _r1 = _mm512_load_ps(p0 + N); - __m512 _r2 = _mm512_load_ps(p0 + 2 * N); - __m512 _r3 = _mm512_load_ps(p0 + 3 * N); - __m512 _r4 = _mm512_load_ps(p0 + 4 * N); - __m512 _r5 = _mm512_load_ps(p0 + 5 * N); - __m512 _r6 = _mm512_load_ps(p0 + 6 * N); - __m512 _r7 = _mm512_load_ps(p0 + 7 * N); - __m512 _r8 = _mm512_load_ps(p0 + 8 * N); - __m512 _r9 = _mm512_load_ps(p0 + 9 * N); - __m512 _ra = _mm512_load_ps(p0 + 10 * N); - __m512 _rb = _mm512_load_ps(p0 + 11 * N); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); + __m512 _r4 = _mm512_load_ps(p0 + 16 * 4); + __m512 _r5 = _mm512_load_ps(p0 + 16 * 5); + __m512 _r6 = _mm512_load_ps(p0 + 16 * 6); + __m512 _r7 = _mm512_load_ps(p0 + 16 * 7); + __m512 _r8 = _mm512_load_ps(p0 + 16 * 8); + __m512 _r9 = _mm512_load_ps(p0 + 16 * 9); + __m512 _ra = _mm512_load_ps(p0 + 16 * 10); + __m512 _rb = _mm512_load_ps(p0 + 16 * 11); transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); _mm512_storeu_ps(pp, _r0); _mm512_storeu_ps(pp + 16, _r1); @@ -161,27 +160,26 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, _mm512_storeu_ps(pp + 16 * 9, _r9); _mm512_storeu_ps(pp + 16 * 10, _ra); _mm512_storeu_ps(pp + 16 * 11, _rb); - p0 += max_jj * N; + p0 += max_jj * batch * 16; pp += 192; } - p0 -= jj * N + b * 16; + p0 -= (b * max_jj + jj) * 16; #endif // __AVX512F__ - N = batch * 8; - p0 += jj * N + b * 8; + p0 += (b * max_jj + jj) * 8; for (; kk + 7 < max_kk; kk += 8) { __m256 _r0 = _mm256_load_ps(p0); - __m256 _r1 = _mm256_load_ps(p0 + N); - __m256 _r2 = _mm256_load_ps(p0 + 2 * N); - __m256 _r3 = _mm256_load_ps(p0 + 3 * N); - __m256 _r4 = _mm256_load_ps(p0 + 4 * N); - __m256 _r5 = _mm256_load_ps(p0 + 5 * N); - __m256 _r6 = _mm256_load_ps(p0 + 6 * N); - __m256 _r7 = _mm256_load_ps(p0 + 7 * N); - __m256 _r8 = _mm256_load_ps(p0 + 8 * N); - __m256 _r9 = _mm256_load_ps(p0 + 9 * N); - __m256 _ra = _mm256_load_ps(p0 + 10 * N); - __m256 _rb = _mm256_load_ps(p0 + 11 * N); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); + __m256 _r4 = _mm256_load_ps(p0 + 8 * 4); + __m256 _r5 = _mm256_load_ps(p0 + 8 * 5); + __m256 _r6 = _mm256_load_ps(p0 + 8 * 6); + __m256 _r7 = _mm256_load_ps(p0 + 8 * 7); + __m256 _r8 = _mm256_load_ps(p0 + 8 * 8); + __m256 _r9 = _mm256_load_ps(p0 + 8 * 9); + __m256 _ra = _mm256_load_ps(p0 + 8 * 10); + __m256 _rb = _mm256_load_ps(p0 + 8 * 11); transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); _mm256_storeu_ps(pp, _r0); _mm256_storeu_ps(pp + 8, _r1); @@ -195,27 +193,26 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, _mm256_storeu_ps(pp + 8 * 9, _r9); _mm256_storeu_ps(pp + 8 * 10, _ra); _mm256_storeu_ps(pp + 8 * 11, _rb); - p0 += max_jj * N; + p0 += max_jj * batch * 8; pp += 96; } - p0 -= jj * N + b * 8; + p0 -= (b * max_jj + jj) * 8; #endif // __AVX__ - N = batch * 4; - p0 += jj * N + b * 4; + p0 += (b * max_jj + jj) * 4; for (; kk + 3 < max_kk; kk += 4) { __m128 _r0 = _mm_load_ps(p0); - __m128 _r1 = _mm_load_ps(p0 + N); - __m128 _r2 = _mm_load_ps(p0 + 2 * N); - __m128 _r3 = _mm_load_ps(p0 + 3 * N); - __m128 _r4 = _mm_load_ps(p0 + 4 * N); - __m128 _r5 = _mm_load_ps(p0 + 5 * N); - __m128 _r6 = _mm_load_ps(p0 + 6 * N); - __m128 _r7 = _mm_load_ps(p0 + 7 * N); - __m128 _r8 = _mm_load_ps(p0 + 8 * N); - __m128 _r9 = _mm_load_ps(p0 + 9 * N); - __m128 _ra = _mm_load_ps(p0 + 10 * N); - __m128 _rb = _mm_load_ps(p0 + 11 * N); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); + __m128 _r4 = _mm_load_ps(p0 + 4 * 4); + __m128 _r5 = _mm_load_ps(p0 + 4 * 5); + __m128 _r6 = _mm_load_ps(p0 + 4 * 6); + __m128 _r7 = _mm_load_ps(p0 + 4 * 7); + __m128 _r8 = _mm_load_ps(p0 + 4 * 8); + __m128 _r9 = _mm_load_ps(p0 + 4 * 9); + __m128 _ra = _mm_load_ps(p0 + 4 * 10); + __m128 _rb = _mm_load_ps(p0 + 4 * 11); _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); @@ -231,83 +228,76 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, _mm_store_ps(pp + 4 * 9, _r3); _mm_store_ps(pp + 4 * 10, _r7); _mm_store_ps(pp + 4 * 11, _rb); - p0 += max_jj * N; + p0 += max_jj * batch * 4; pp += 48; } - p0 -= jj * N + b * 4; - N = batch * 2; - p0 += jj * N + b * 2; + p0 -= (b * max_jj + jj) * 2; for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; - pp[4] = p0[4 * N]; - pp[5] = p0[5 * N]; - pp[6] = p0[6 * N]; - pp[7] = p0[7 * N]; - pp[8] = p0[8 * N]; - pp[9] = p0[9 * N]; - pp[10] = p0[10 * N]; - pp[11] = p0[11 * N]; + pp[1] = p0[2]; + pp[2] = p0[2 * 2]; + pp[3] = p0[3 * 2]; + pp[4] = p0[4 * 2]; + pp[5] = p0[5 * 2]; + pp[6] = p0[6 * 2]; + pp[7] = p0[7 * 2]; + pp[8] = p0[8 * 2]; + pp[9] = p0[9 * 2]; + pp[10] = p0[10 * 2]; + pp[11] = p0[11 * 2]; pp[12] = p0[1]; - pp[13] = p0[N + 1]; - pp[14] = p0[2 * N + 1]; - pp[15] = p0[3 * N + 1]; - pp[16] = p0[4 * N + 1]; - pp[17] = p0[5 * N + 1]; - pp[18] = p0[6 * N + 1]; - pp[19] = p0[7 * N + 1]; - pp[20] = p0[8 * N + 1]; - pp[21] = p0[9 * N + 1]; - pp[22] = p0[10 * N + 1]; - pp[23] = p0[11 * N + 1]; - p0 += max_jj * N; + pp[13] = p0[2 + 1]; + pp[14] = p0[2 * 2 + 1]; + pp[15] = p0[3 * 2 + 1]; + pp[16] = p0[4 * 2 + 1]; + pp[17] = p0[5 * 2 + 1]; + pp[18] = p0[6 * 2 + 1]; + pp[19] = p0[7 * 2 + 1]; + pp[20] = p0[8 * 2 + 1]; + pp[21] = p0[9 * 2 + 1]; + pp[22] = p0[10 * 2 + 1]; + pp[23] = p0[11 * 2 + 1]; + p0 += max_jj * batch * 2; pp += 24; } - p0 -= jj * N + b * 2; - N = batch; - p0 += jj * N + b; + p0 -= (b * max_jj + jj); for (; kk < max_kk; kk++) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; - pp[4] = p0[4 * N]; - pp[5] = p0[5 * N]; - pp[6] = p0[6 * N]; - pp[7] = p0[7 * N]; - pp[8] = p0[8 * N]; - pp[9] = p0[9 * N]; - pp[10] = p0[10 * N]; - pp[11] = p0[11 * N]; - p0 += max_jj * N; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p0[4]; + pp[5] = p0[5]; + pp[6] = p0[6]; + pp[7] = p0[7]; + pp[8] = p0[8]; + pp[9] = p0[9]; + pp[10] = p0[10]; + pp[11] = p0[11]; + p0 += max_jj * batch; pp += 12; } - p0 -= jj * N + b; } for (; jj + 7 < max_jj; jj += 8) { - int N = batch; const float* p0 = B; int kk = 0; #if __AVX__ #if __AVX512F__ - N = batch * 16; - p0 += jj * N + b * 16; + p0 += (b * max_jj + jj) * 16; for (; kk + 15 < max_kk; kk += 16) { __m512 _r0 = _mm512_load_ps(p0); - __m512 _r1 = _mm512_load_ps(p0 + N); - __m512 _r2 = _mm512_load_ps(p0 + 2 * N); - __m512 _r3 = _mm512_load_ps(p0 + 3 * N); - __m512 _r4 = _mm512_load_ps(p0 + 4 * N); - __m512 _r5 = _mm512_load_ps(p0 + 5 * N); - __m512 _r6 = _mm512_load_ps(p0 + 6 * N); - __m512 _r7 = _mm512_load_ps(p0 + 7 * N); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); + __m512 _r4 = _mm512_load_ps(p0 + 16 * 4); + __m512 _r5 = _mm512_load_ps(p0 + 16 * 5); + __m512 _r6 = _mm512_load_ps(p0 + 16 * 6); + __m512 _r7 = _mm512_load_ps(p0 + 16 * 7); transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); _mm512_storeu_ps(pp, _r0); _mm512_storeu_ps(pp + 16, _r1); @@ -317,49 +307,47 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, _mm512_storeu_ps(pp + 16 * 5, _r5); _mm512_storeu_ps(pp + 16 * 6, _r6); _mm512_storeu_ps(pp + 16 * 7, _r7); - p0 += max_jj * N; + p0 += max_jj * batch * 16; pp += 128; } - p0 -= jj * N + b * 16; + p0 -= (b * max_jj + jj) * 16; #endif // __AVX512F__ - N = batch * 8; - p0 += jj * N + b * 8; + p0 += (b * max_jj + jj) * 8; for (; kk + 7 < max_kk; kk += 8) { __m256 _r0 = _mm256_load_ps(p0); - __m256 _r1 = _mm256_load_ps(p0 + N); - __m256 _r2 = _mm256_load_ps(p0 + 2 * N); - __m256 _r3 = _mm256_load_ps(p0 + 3 * N); - __m256 _r4 = _mm256_load_ps(p0 + 4 * N); - __m256 _r5 = _mm256_load_ps(p0 + 5 * N); - __m256 _r6 = _mm256_load_ps(p0 + 6 * N); - __m256 _r7 = _mm256_load_ps(p0 + 7 * N); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); + __m256 _r4 = _mm256_load_ps(p0 + 8 * 4); + __m256 _r5 = _mm256_load_ps(p0 + 8 * 5); + __m256 _r6 = _mm256_load_ps(p0 + 8 * 6); + __m256 _r7 = _mm256_load_ps(p0 + 8 * 7); transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); _mm256_storeu_ps(pp, _r0); _mm256_storeu_ps(pp + 8, _r1); - _mm256_storeu_ps(pp + 16, _r2); - _mm256_storeu_ps(pp + 24, _r3); - _mm256_storeu_ps(pp + 32, _r4); - _mm256_storeu_ps(pp + 40, _r5); - _mm256_storeu_ps(pp + 48, _r6); - _mm256_storeu_ps(pp + 56, _r7); - p0 += max_jj * N; + _mm256_storeu_ps(pp + 8 * 2, _r2); + _mm256_storeu_ps(pp + 8 * 3, _r3); + _mm256_storeu_ps(pp + 8 * 4, _r4); + _mm256_storeu_ps(pp + 8 * 5, _r5); + _mm256_storeu_ps(pp + 8 * 6, _r6); + _mm256_storeu_ps(pp + 8 * 7, _r7); + p0 += max_jj * batch * 8; pp += 64; } - p0 -= jj * N + b * 8; + p0 -= (b * max_jj + jj) * 8; #endif // __AVX__ - N = batch * 4; - p0 += jj * N + b * 4; + p0 += (b * max_jj + jj) * 4; for (; kk + 3 < max_kk; kk += 4) { __m128 _r0 = _mm_load_ps(p0); - __m128 _r1 = _mm_load_ps(p0 + N); - __m128 _r2 = _mm_load_ps(p0 + 2 * N); - __m128 _r3 = _mm_load_ps(p0 + 3 * N); - __m128 _r4 = _mm_load_ps(p0 + 4 * N); - __m128 _r5 = _mm_load_ps(p0 + 5 * N); - __m128 _r6 = _mm_load_ps(p0 + 6 * N); - __m128 _r7 = _mm_load_ps(p0 + 7 * N); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); + __m128 _r4 = _mm_load_ps(p0 + 4 * 4); + __m128 _r5 = _mm_load_ps(p0 + 4 * 5); + __m128 _r6 = _mm_load_ps(p0 + 4 * 6); + __m128 _r7 = _mm_load_ps(p0 + 4 * 7); _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); _mm_store_ps(pp, _r0); @@ -370,302 +358,272 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int batch, int max_jj, _mm_store_ps(pp + 20, _r6); _mm_store_ps(pp + 24, _r3); _mm_store_ps(pp + 28, _r7); - p0 += max_jj * N; + p0 += max_jj * batch * 4; pp += 32; } - p0 -= jj * N + b * 4; - N = batch * 2; - p0 += jj * N + b * 2; + p0 -= (b * max_jj + jj) * 4; + p0 += (b * max_jj + jj) * 2; for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; - pp[4] = p0[4 * N]; - pp[5] = p0[5 * N]; - pp[6] = p0[6 * N]; - pp[7] = p0[7 * N]; + pp[1] = p0[2]; + pp[2] = p0[4]; + pp[3] = p0[6]; + pp[4] = p0[8]; + pp[5] = p0[10]; + pp[6] = p0[12]; + pp[7] = p0[14]; pp[8] = p0[1]; - pp[9] = p0[N + 1]; - pp[10] = p0[2 * N + 1]; - pp[11] = p0[3 * N + 1]; - pp[12] = p0[4 * N + 1]; - pp[13] = p0[5 * N + 1]; - pp[14] = p0[6 * N + 1]; - pp[15] = p0[7 * N + 1]; - p0 += max_jj * N; + pp[9] = p0[3]; + pp[10] = p0[5]; + pp[11] = p0[7]; + pp[12] = p0[9]; + pp[13] = p0[11]; + pp[14] = p0[13]; + pp[15] = p0[15]; + p0 += max_jj * batch * 2; pp += 16; } - p0 -= jj * N + b * 2; - N = batch; - p0 += jj * N + b; + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); for (; kk < max_kk; kk++) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; - pp[4] = p0[4 * N]; - pp[5] = p0[5 * N]; - pp[6] = p0[6 * N]; - pp[7] = p0[7 * N]; - p0 += max_jj * N; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p0[4]; + pp[5] = p0[5]; + pp[6] = p0[6]; + pp[7] = p0[7]; + p0 += max_jj * batch; pp += 8; } - p0 -= jj * N + b; } for (; jj + 3 < max_jj; jj += 4) { - int N = batch; const float* p0 = B; int kk = 0; #if __AVX__ #if __AVX512F__ - N = batch * 16; - p0 += jj * N + b * 16; + p0 += (b * max_jj + jj) * 16; for (; kk + 15 < max_kk; kk += 16) { __m512 _r0 = _mm512_load_ps(p0); - __m512 _r1 = _mm512_load_ps(p0 + N); - __m512 _r2 = _mm512_load_ps(p0 + 2 * N); - __m512 _r3 = _mm512_load_ps(p0 + 3 * N); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); transpose16x4_ps(_r0, _r1, _r2, _r3); _mm512_storeu_ps(pp, _r0); _mm512_storeu_ps(pp + 16, _r1); _mm512_storeu_ps(pp + 32, _r2); _mm512_storeu_ps(pp + 48, _r3); - p0 += max_jj * N; + p0 += max_jj * batch * 16; pp += 64; } - p0 -= jj * N + b * 16; + p0 -= (b * max_jj + jj) * 16; #endif // __AVX512F__ - N = batch * 8; - p0 += jj * N + b * 8; + p0 += (b * max_jj + jj) * 8; for (; kk + 7 < max_kk; kk += 8) { __m256 _r0 = _mm256_load_ps(p0); - __m256 _r1 = _mm256_load_ps(p0 + N); - __m256 _r2 = _mm256_load_ps(p0 + 2 * N); - __m256 _r3 = _mm256_load_ps(p0 + 3 * N); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); transpose8x4_ps(_r0, _r1, _r2, _r3); _mm256_storeu_ps(pp, _r0); _mm256_storeu_ps(pp + 8, _r1); - _mm256_storeu_ps(pp + 16, _r2); - _mm256_storeu_ps(pp + 24, _r3); - p0 += max_jj * N; + _mm256_storeu_ps(pp + 8 * 2, _r2); + _mm256_storeu_ps(pp + 8 * 3, _r3); + p0 += max_jj * batch * 8; pp += 32; } - p0 -= jj * N + b * 8; + p0 -= (b * max_jj + jj) * 8; #endif // __AVX__ - N = batch * 4; - p0 += jj * N + b * 4; + p0 += (b * max_jj + jj) * 4; for (; kk + 3 < max_kk; kk += 4) { __m128 _r0 = _mm_load_ps(p0); - __m128 _r1 = _mm_load_ps(p0 + N); - __m128 _r2 = _mm_load_ps(p0 + 2 * N); - __m128 _r3 = _mm_load_ps(p0 + 3 * N); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); _mm_store_ps(pp, _r0); _mm_store_ps(pp + 4, _r1); _mm_store_ps(pp + 8, _r2); _mm_store_ps(pp + 12, _r3); - p0 += max_jj * N; + p0 += max_jj * batch * 4; pp += 16; } - p0 -= jj * N + b * 4; - N = batch * 2; - p0 += jj * N + b * 2; + p0 -= (b * max_jj + jj) * 4; + p0 += (b * max_jj + jj) * 2; for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; + pp[1] = p0[2]; + pp[2] = p0[4]; + pp[3] = p0[6]; pp[4] = p0[1]; - pp[5] = p0[N + 1]; - pp[6] = p0[2 * N + 1]; - pp[7] = p0[3 * N + 1]; - p0 += max_jj * N; + pp[5] = p0[3]; + pp[6] = p0[5]; + pp[7] = p0[7]; + p0 += max_jj * batch * 2; pp += 8; } - p0 -= jj * N + b * 2; - N = batch; - p0 += jj * N + b; + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); for (; kk < max_kk; kk++) { pp[0] = p0[0]; - pp[1] = p0[N]; - pp[2] = p0[2 * N]; - pp[3] = p0[3 * N]; - p0 += max_jj * N; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch; pp += 4; } - p0 -= jj * N + b; } #endif // __SSE2__ for (; jj + 1 < max_jj; jj += 2) { - int N = batch; const float* p0 = B; int kk = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - N = batch * 16; - p0 += jj * N + b * 16; + p0 += (b * max_jj + jj) * 16; for (; kk + 15 < max_kk; kk += 16) { __m512 _r0 = _mm512_load_ps(p0); - __m512 _r1 = _mm512_load_ps(p0 + N); + __m512 _r1 = _mm512_load_ps(p0 + 16); transpose16x2_ps(_r0, _r1); _mm512_storeu_ps(pp, _r0); _mm512_storeu_ps(pp + 16, _r1); - p0 += max_jj * N; + p0 += max_jj * batch * 16; pp += 32; } - p0 -= jj * N + b * 16; + p0 -= (b * max_jj + jj) * 16; #endif // __AVX512F__ - N = batch * 8; - p0 += jj * N + b * 8; + p0 += (b * max_jj + jj) * 8; for (; kk + 7 < max_kk; kk += 8) { __m256 _r0 = _mm256_load_ps(p0); - __m256 _r1 = _mm256_load_ps(p0 + N); + __m256 _r1 = _mm256_load_ps(p0 + 8); transpose8x2_ps(_r0, _r1); _mm256_storeu_ps(pp, _r0); _mm256_storeu_ps(pp + 8, _r1); - p0 += max_jj * N; + p0 += max_jj * batch * 8; pp += 16; } - p0 -= jj * N + b * 8; + p0 -= (b * max_jj + jj) * 8; #endif // __AVX__ - N = batch * 4; - p0 += jj * N + b * 4; + p0 += (b * max_jj + jj) * 4; for (; kk + 3 < max_kk; kk += 4) { __m128 _r0 = _mm_load_ps(p0); - __m128 _r1 = _mm_load_ps(p0 + N); + __m128 _r1 = _mm_load_ps(p0 + 4); __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); - _mm_store_ps(pp, _tmp0); - _mm_store_ps(pp + 4, _tmp1); - p0 += max_jj * N; + _mm_storeu_ps(pp, _tmp0); + _mm_storeu_ps(pp + 4, _tmp1); + p0 += max_jj * batch * 4; pp += 8; } - p0 -= jj * N + b * 4; + p0 -= (b * max_jj + jj) * 4; #endif // __SSE2__ - N = batch * 2; - p0 += jj * N + b * 2; + p0 += (b * max_jj + jj) * 2; for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; - pp[1] = p0[N]; + pp[1] = p0[2]; pp[2] = p0[1]; - pp[3] = p0[N + 1]; - p0 += max_jj * N; + pp[3] = p0[3]; + p0 += max_jj * batch * 2; pp += 4; } - p0 -= jj * N + b * 2; - N = batch; - p0 += jj * N + b; + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); for (; kk < max_kk; kk++) { pp[0] = p0[0]; - pp[1] = p0[N]; - p0 += max_jj * N; + pp[1] = p0[1]; + p0 += max_jj * batch; pp += 2; } - p0 -= jj * N + b; } for (; jj < max_jj; jj++) { - int N = batch; const float* p0 = B; int kk = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - N = batch * 16; - p0 += jj * N + b * 16; + p0 += (b * max_jj + jj) * 16; for (; kk + 15 < max_kk; kk += 16) { __m512 _r0 = _mm512_load_ps(p0); _mm512_storeu_ps(pp, _r0); - p0 += max_jj * N; + p0 += max_jj * batch * 16; pp += 16; } - p0 -= jj * N + b * 16; + p0 -= (b * max_jj + jj) * 16; #endif // __AVX512F__ - N = batch * 8; - p0 += jj * N + b * 8; + p0 += (b * max_jj + jj) * 8; for (; kk + 7 < max_kk; kk += 8) { __m256 _r0 = _mm256_load_ps(p0); _mm256_storeu_ps(pp, _r0); - p0 += max_jj * N; + p0 += max_jj * batch * 8; pp += 8; } - p0 -= jj * N + b * 8; + p0 -= (b * max_jj + jj) * 8; #endif // __AVX__ - N = batch * 4; - p0 += jj * N + b * 4; + p0 += (b * max_jj + jj) * 4; for (; kk + 3 < max_kk; kk += 4) { __m128 _r0 = _mm_load_ps(p0); _mm_storeu_ps(pp, _r0); - p0 += max_jj * N; + p0 += max_jj * batch * 4; pp += 4; } - p0 -= jj * N + b * 4; + p0 -= (b * max_jj + jj) * 4; #endif // __SSE2__ - N = batch * 2; - p0 += jj * N + b * 2; + p0 += (b * max_jj + jj) * 2; for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; pp[1] = p0[1]; - p0 += max_jj * N; + p0 += max_jj * batch * 2; pp += 2; } - p0 -= jj * N + b * 2; - N = batch; - p0 += jj * N + b; + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); for (; kk < max_kk; kk++) { pp[0] = p0[0]; - p0 += max_jj * N; + p0 += max_jj * batch; pp += 1; } - p0 -= jj * N + b; } } } -static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, Mat& tmp, int batch, int max_ii, int max_jj, int k, int max_kk, bool k_end) +static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk) { - const int TILE_M = top_blob.w; + float* outptr = top_blob; - for (int b = 0; b < batch; b++) - { - const float* pAT = AT_tile.row(b); - const float* pBT = BT_tile.row(b); - - float* ptmp = tmp.row(b); - - int ii = 0; + int ii = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - for (; ii + 15 < max_ii; ii += 16) + for (; ii + 15 < max_ii; ii += 16) + { + for (int b = 0; b < batch; b++) { - const float* pB = pBT; - - float* outptr = (float*)top_blob.depth(b) + ii; + const float* pAT = AT_tile.row(b) + max_kk * ii; + const float* pB = BT_tile.row(b); int jj = 0; for (; jj + 11 < max_jj; jj += 12) @@ -702,18 +660,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm512_load_ps(ptmp); - _sum1 = _mm512_load_ps(ptmp + 16); - _sum2 = _mm512_load_ps(ptmp + 16 * 2); - _sum3 = _mm512_load_ps(ptmp + 16 * 3); - _sum4 = _mm512_load_ps(ptmp + 16 * 4); - _sum5 = _mm512_load_ps(ptmp + 16 * 5); - _sum6 = _mm512_load_ps(ptmp + 16 * 6); - _sum7 = _mm512_load_ps(ptmp + 16 * 7); - _sum8 = _mm512_load_ps(ptmp + 16 * 8); - _sum9 = _mm512_load_ps(ptmp + 16 * 9); - _suma = _mm512_load_ps(ptmp + 16 * 10); - _sumb = _mm512_load_ps(ptmp + 16 * 11); + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); + _sum4 = _mm512_load_ps(outptr + 16 * 4); + _sum5 = _mm512_load_ps(outptr + 16 * 5); + _sum6 = _mm512_load_ps(outptr + 16 * 6); + _sum7 = _mm512_load_ps(outptr + 16 * 7); + _sum8 = _mm512_load_ps(outptr + 16 * 8); + _sum9 = _mm512_load_ps(outptr + 16 * 9); + _suma = _mm512_load_ps(outptr + 16 * 10); + _sumb = _mm512_load_ps(outptr + 16 * 11); } int kk = 0; @@ -737,39 +695,19 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 12; } - if (k_end) - { - _mm512_store_ps(outptr, _sum0); - _mm512_store_ps(outptr + TILE_M, _sum1); - _mm512_store_ps(outptr + TILE_M * 2, _sum2); - _mm512_store_ps(outptr + TILE_M * 3, _sum3); - _mm512_store_ps(outptr + TILE_M * 4, _sum4); - _mm512_store_ps(outptr + TILE_M * 5, _sum5); - _mm512_store_ps(outptr + TILE_M * 6, _sum6); - _mm512_store_ps(outptr + TILE_M * 7, _sum7); - _mm512_store_ps(outptr + TILE_M * 8, _sum8); - _mm512_store_ps(outptr + TILE_M * 9, _sum9); - _mm512_store_ps(outptr + TILE_M * 10, _suma); - _mm512_store_ps(outptr + TILE_M * 11, _sumb); - outptr += TILE_M * 12; - } - else - { - _mm512_store_ps(ptmp, _sum0); - _mm512_store_ps(ptmp + 16, _sum1); - _mm512_store_ps(ptmp + 16 * 2, _sum2); - _mm512_store_ps(ptmp + 16 * 3, _sum3); - _mm512_store_ps(ptmp + 16 * 4, _sum4); - _mm512_store_ps(ptmp + 16 * 5, _sum5); - _mm512_store_ps(ptmp + 16 * 6, _sum6); - _mm512_store_ps(ptmp + 16 * 7, _sum7); - _mm512_store_ps(ptmp + 16 * 8, _sum8); - _mm512_store_ps(ptmp + 16 * 9, _sum9); - _mm512_store_ps(ptmp + 16 * 10, _suma); - _mm512_store_ps(ptmp + 16 * 11, _sumb); - } - - ptmp += 192; + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + _mm512_store_ps(outptr + 16 * 4, _sum4); + _mm512_store_ps(outptr + 16 * 5, _sum5); + _mm512_store_ps(outptr + 16 * 6, _sum6); + _mm512_store_ps(outptr + 16 * 7, _sum7); + _mm512_store_ps(outptr + 16 * 8, _sum8); + _mm512_store_ps(outptr + 16 * 9, _sum9); + _mm512_store_ps(outptr + 16 * 10, _suma); + _mm512_store_ps(outptr + 16 * 11, _sumb); + outptr += 16 * 12; } for (; jj + 7 < max_jj; jj += 8) { @@ -797,14 +735,14 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm512_load_ps(ptmp); - _sum1 = _mm512_load_ps(ptmp + 16); - _sum2 = _mm512_load_ps(ptmp + 16 * 2); - _sum3 = _mm512_load_ps(ptmp + 16 * 3); - _sum4 = _mm512_load_ps(ptmp + 16 * 4); - _sum5 = _mm512_load_ps(ptmp + 16 * 5); - _sum6 = _mm512_load_ps(ptmp + 16 * 6); - _sum7 = _mm512_load_ps(ptmp + 16 * 7); + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); + _sum4 = _mm512_load_ps(outptr + 16 * 4); + _sum5 = _mm512_load_ps(outptr + 16 * 5); + _sum6 = _mm512_load_ps(outptr + 16 * 6); + _sum7 = _mm512_load_ps(outptr + 16 * 7); } int kk = 0; @@ -824,31 +762,15 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 8; } - if (k_end) - { - _mm512_store_ps(outptr, _sum0); - _mm512_store_ps(outptr + TILE_M, _sum1); - _mm512_store_ps(outptr + TILE_M * 2, _sum2); - _mm512_store_ps(outptr + TILE_M * 3, _sum3); - _mm512_store_ps(outptr + TILE_M * 4, _sum4); - _mm512_store_ps(outptr + TILE_M * 5, _sum5); - _mm512_store_ps(outptr + TILE_M * 6, _sum6); - _mm512_store_ps(outptr + TILE_M * 7, _sum7); - outptr += TILE_M * 8; - } - else - { - _mm512_store_ps(ptmp, _sum0); - _mm512_store_ps(ptmp + 16, _sum1); - _mm512_store_ps(ptmp + 16 * 2, _sum2); - _mm512_store_ps(ptmp + 16 * 3, _sum3); - _mm512_store_ps(ptmp + 16 * 4, _sum4); - _mm512_store_ps(ptmp + 16 * 5, _sum5); - _mm512_store_ps(ptmp + 16 * 6, _sum6); - _mm512_store_ps(ptmp + 16 * 7, _sum7); - } - - ptmp += 128; + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + _mm512_store_ps(outptr + 16 * 4, _sum4); + _mm512_store_ps(outptr + 16 * 5, _sum5); + _mm512_store_ps(outptr + 16 * 6, _sum6); + _mm512_store_ps(outptr + 16 * 7, _sum7); + outptr += 16 * 8; } for (; jj + 3 < max_jj; jj += 4) { @@ -868,10 +790,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm512_load_ps(ptmp); - _sum1 = _mm512_load_ps(ptmp + 16); - _sum2 = _mm512_load_ps(ptmp + 16 * 2); - _sum3 = _mm512_load_ps(ptmp + 16 * 3); + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); } int kk = 0; @@ -887,23 +809,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 4; } - if (k_end) - { - _mm512_store_ps(outptr, _sum0); - _mm512_store_ps(outptr + TILE_M * 1, _sum1); - _mm512_store_ps(outptr + TILE_M * 2, _sum2); - _mm512_store_ps(outptr + TILE_M * 3, _sum3); - outptr += TILE_M * 4; - } - else - { - _mm512_store_ps(ptmp, _sum0); - _mm512_store_ps(ptmp + 16, _sum1); - _mm512_store_ps(ptmp + 16 * 2, _sum2); - _mm512_store_ps(ptmp + 16 * 3, _sum3); - } - - ptmp += 64; + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + outptr += 16 * 4; } for (; jj + 1 < max_jj; jj += 2) { @@ -919,8 +829,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm512_load_ps(ptmp); - _sum1 = _mm512_load_ps(ptmp + 8); + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); } int kk = 0; @@ -934,19 +844,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 2; } - if (k_end) - { - _mm512_store_ps(outptr, _sum0); - _mm512_store_ps(outptr + TILE_M, _sum1); - outptr += TILE_M * 2; - } - else - { - _mm512_store_ps(ptmp, _sum0); - _mm512_store_ps(ptmp + 16, _sum1); - } - - ptmp += 32; + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + outptr += 16 * 2; } for (; jj < max_jj; jj++) { @@ -960,7 +860,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum = _mm512_load_ps(ptmp); + _sum = _mm512_load_ps(outptr); } int kk = 0; @@ -974,27 +874,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 1; } - if (k_end) - { - _mm512_store_ps(outptr, _sum); - outptr += TILE_M; - } - else - { - _mm512_store_ps(ptmp, _sum); - } - - ptmp += 16; + _mm512_store_ps(outptr, _sum); + outptr += 16; } - - pAT += max_kk * 16; } + } #endif // __AVX512F__ - for (; ii + 7 < max_ii; ii += 8) + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) { - const float* pB = pBT; - - float* outptr = (float*)top_blob.depth(b) + ii; + const float* pAT = AT_tile.row(b) + max_kk * ii; + const float* pB = BT_tile.row(b); int jj = 0; for (; jj + 11 < max_jj; jj += 12) @@ -1031,18 +922,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm256_load_ps(ptmp); - _sum1 = _mm256_load_ps(ptmp + 8); - _sum2 = _mm256_load_ps(ptmp + 16); - _sum3 = _mm256_load_ps(ptmp + 24); - _sum4 = _mm256_load_ps(ptmp + 32); - _sum5 = _mm256_load_ps(ptmp + 40); - _sum6 = _mm256_load_ps(ptmp + 48); - _sum7 = _mm256_load_ps(ptmp + 56); - _sum8 = _mm256_load_ps(ptmp + 64); - _sum9 = _mm256_load_ps(ptmp + 72); - _suma = _mm256_load_ps(ptmp + 80); - _sumb = _mm256_load_ps(ptmp + 88); + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); + _sum4 = _mm256_load_ps(outptr + 32); + _sum5 = _mm256_load_ps(outptr + 40); + _sum6 = _mm256_load_ps(outptr + 48); + _sum7 = _mm256_load_ps(outptr + 56); + _sum8 = _mm256_load_ps(outptr + 64); + _sum9 = _mm256_load_ps(outptr + 72); + _suma = _mm256_load_ps(outptr + 80); + _sumb = _mm256_load_ps(outptr + 88); } int kk = 0; @@ -1066,38 +957,19 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 12; } - if (k_end) - { - _mm256_store_ps(outptr, _sum0); - _mm256_store_ps(outptr + TILE_M, _sum1); - _mm256_store_ps(outptr + TILE_M * 2, _sum2); - _mm256_store_ps(outptr + TILE_M * 3, _sum3); - _mm256_store_ps(outptr + TILE_M * 4, _sum4); - _mm256_store_ps(outptr + TILE_M * 5, _sum5); - _mm256_store_ps(outptr + TILE_M * 6, _sum6); - _mm256_store_ps(outptr + TILE_M * 7, _sum7); - _mm256_store_ps(outptr + TILE_M * 8, _sum8); - _mm256_store_ps(outptr + TILE_M * 9, _sum9); - _mm256_store_ps(outptr + TILE_M * 10, _suma); - _mm256_store_ps(outptr + TILE_M * 11, _sumb); - outptr += TILE_M * 12; - } - else - { - _mm256_store_ps(ptmp, _sum0); - _mm256_store_ps(ptmp + 8, _sum1); - _mm256_store_ps(ptmp + 16, _sum2); - _mm256_store_ps(ptmp + 24, _sum3); - _mm256_store_ps(ptmp + 32, _sum4); - _mm256_store_ps(ptmp + 40, _sum5); - _mm256_store_ps(ptmp + 48, _sum6); - _mm256_store_ps(ptmp + 56, _sum7); - _mm256_store_ps(ptmp + 64, _sum8); - _mm256_store_ps(ptmp + 72, _sum9); - _mm256_store_ps(ptmp + 80, _suma); - _mm256_store_ps(ptmp + 88, _sumb); - } - ptmp += 96; + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + _mm256_store_ps(outptr + 8 * 4, _sum4); + _mm256_store_ps(outptr + 8 * 5, _sum5); + _mm256_store_ps(outptr + 8 * 6, _sum6); + _mm256_store_ps(outptr + 8 * 7, _sum7); + _mm256_store_ps(outptr + 8 * 8, _sum8); + _mm256_store_ps(outptr + 8 * 9, _sum9); + _mm256_store_ps(outptr + 8 * 10, _suma); + _mm256_store_ps(outptr + 8 * 11, _sumb); + outptr += 8 * 12; } for (; jj + 7 < max_jj; jj += 8) { @@ -1125,14 +997,14 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm256_load_ps(ptmp); - _sum1 = _mm256_load_ps(ptmp + 8); - _sum2 = _mm256_load_ps(ptmp + 16); - _sum3 = _mm256_load_ps(ptmp + 24); - _sum4 = _mm256_load_ps(ptmp + 32); - _sum5 = _mm256_load_ps(ptmp + 40); - _sum6 = _mm256_load_ps(ptmp + 48); - _sum7 = _mm256_load_ps(ptmp + 56); + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); + _sum4 = _mm256_load_ps(outptr + 32); + _sum5 = _mm256_load_ps(outptr + 40); + _sum6 = _mm256_load_ps(outptr + 48); + _sum7 = _mm256_load_ps(outptr + 56); } int kk = 0; @@ -1152,31 +1024,15 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 8; } - if (k_end) - { - _mm256_store_ps(outptr, _sum0); - _mm256_store_ps(outptr + TILE_M, _sum1); - _mm256_store_ps(outptr + TILE_M * 2, _sum2); - _mm256_store_ps(outptr + TILE_M * 3, _sum3); - _mm256_store_ps(outptr + TILE_M * 4, _sum4); - _mm256_store_ps(outptr + TILE_M * 5, _sum5); - _mm256_store_ps(outptr + TILE_M * 6, _sum6); - _mm256_store_ps(outptr + TILE_M * 7, _sum7); - outptr += TILE_M * 8; - } - else - { - _mm256_store_ps(ptmp, _sum0); - _mm256_store_ps(ptmp + 8, _sum1); - _mm256_store_ps(ptmp + 16, _sum2); - _mm256_store_ps(ptmp + 24, _sum3); - _mm256_store_ps(ptmp + 32, _sum4); - _mm256_store_ps(ptmp + 40, _sum5); - _mm256_store_ps(ptmp + 48, _sum6); - _mm256_store_ps(ptmp + 56, _sum7); - } - - ptmp += 64; + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + _mm256_store_ps(outptr + 8 * 4, _sum4); + _mm256_store_ps(outptr + 8 * 5, _sum5); + _mm256_store_ps(outptr + 8 * 6, _sum6); + _mm256_store_ps(outptr + 8 * 7, _sum7); + outptr += 8 * 8; } for (; jj + 3 < max_jj; jj += 4) { @@ -1196,10 +1052,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm256_load_ps(ptmp); - _sum1 = _mm256_load_ps(ptmp + 8); - _sum2 = _mm256_load_ps(ptmp + 16); - _sum3 = _mm256_load_ps(ptmp + 24); + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); } int kk = 0; @@ -1215,23 +1071,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 4; } - if (k_end) - { - _mm256_store_ps(outptr, _sum0); - _mm256_store_ps(outptr + TILE_M, _sum1); - _mm256_store_ps(outptr + TILE_M * 2, _sum2); - _mm256_store_ps(outptr + TILE_M * 3, _sum3); - outptr += TILE_M * 4; - } - else - { - _mm256_store_ps(ptmp, _sum0); - _mm256_store_ps(ptmp + 8, _sum1); - _mm256_store_ps(ptmp + 16, _sum2); - _mm256_store_ps(ptmp + 24, _sum3); - } - - ptmp += 32; + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + outptr += 8 * 4; } for (; jj + 1 < max_jj; jj += 2) { @@ -1247,8 +1091,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm256_load_ps(ptmp); - _sum1 = _mm256_load_ps(ptmp + 8); + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); } int kk = 0; @@ -1262,19 +1106,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 2; } - if (k_end) - { - _mm256_store_ps(outptr, _sum0); - _mm256_store_ps(outptr + TILE_M, _sum1); - outptr += TILE_M * 2; - } - else - { - _mm256_store_ps(ptmp, _sum0); - _mm256_store_ps(ptmp + 8, _sum1); - } - - ptmp += 16; + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + outptr += 8 * 2; } for (; jj < max_jj; jj++) { @@ -1288,7 +1122,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum = _mm256_load_ps(ptmp); + _sum = _mm256_load_ps(outptr); } int kk = 0; @@ -1302,27 +1136,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 1; } - if (k_end) - { - _mm256_store_ps(outptr, _sum); - outptr += TILE_M; - } - else - { - _mm256_store_ps(ptmp, _sum); - } - - ptmp += 8; + _mm256_store_ps(outptr, _sum); + outptr += 8; } - - pAT += max_kk * 8; } + } #endif // __AVX__ - for (; ii + 3 < max_ii; ii += 4) + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) { - const float* pB = pBT; - - float* outptr = (float*)top_blob.depth(b) + ii; + const float* pAT = AT_tile.row(b) + max_kk * ii; + const float* pB = BT_tile.row(b); int jj = 0; for (; jj + 11 < max_jj; jj += 12) @@ -1359,18 +1184,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); - _sum3 = _mm_load_ps(ptmp + 12); - _sum4 = _mm_load_ps(ptmp + 16); - _sum5 = _mm_load_ps(ptmp + 20); - _sum6 = _mm_load_ps(ptmp + 24); - _sum7 = _mm_load_ps(ptmp + 28); - _sum8 = _mm_load_ps(ptmp + 32); - _sum9 = _mm_load_ps(ptmp + 36); - _suma = _mm_load_ps(ptmp + 40); - _sumb = _mm_load_ps(ptmp + 44); + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); + _sum4 = _mm_load_ps(outptr + 16); + _sum5 = _mm_load_ps(outptr + 20); + _sum6 = _mm_load_ps(outptr + 24); + _sum7 = _mm_load_ps(outptr + 28); + _sum8 = _mm_load_ps(outptr + 32); + _sum9 = _mm_load_ps(outptr + 36); + _suma = _mm_load_ps(outptr + 40); + _sumb = _mm_load_ps(outptr + 44); } int kk = 0; @@ -1394,39 +1219,19 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 12; } - if (k_end) - { - _mm_store_ps(outptr, _sum0); - _mm_store_ps(outptr + TILE_M, _sum1); - _mm_store_ps(outptr + TILE_M * 2, _sum2); - _mm_store_ps(outptr + TILE_M * 3, _sum3); - _mm_store_ps(outptr + TILE_M * 4, _sum4); - _mm_store_ps(outptr + TILE_M * 5, _sum5); - _mm_store_ps(outptr + TILE_M * 6, _sum6); - _mm_store_ps(outptr + TILE_M * 7, _sum7); - _mm_store_ps(outptr + TILE_M * 8, _sum8); - _mm_store_ps(outptr + TILE_M * 9, _sum9); - _mm_store_ps(outptr + TILE_M * 10, _suma); - _mm_store_ps(outptr + TILE_M * 11, _sumb); - outptr += TILE_M * 12; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - _mm_store_ps(ptmp + 12, _sum3); - _mm_store_ps(ptmp + 16, _sum4); - _mm_store_ps(ptmp + 20, _sum5); - _mm_store_ps(ptmp + 24, _sum6); - _mm_store_ps(ptmp + 28, _sum7); - _mm_store_ps(ptmp + 32, _sum8); - _mm_store_ps(ptmp + 36, _sum9); - _mm_store_ps(ptmp + 40, _suma); - _mm_store_ps(ptmp + 44, _sumb); - } - - ptmp += 48; + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + _mm_store_ps(outptr + 4 * 4, _sum4); + _mm_store_ps(outptr + 4 * 5, _sum5); + _mm_store_ps(outptr + 4 * 6, _sum6); + _mm_store_ps(outptr + 4 * 7, _sum7); + _mm_store_ps(outptr + 4 * 8, _sum8); + _mm_store_ps(outptr + 4 * 9, _sum9); + _mm_store_ps(outptr + 4 * 10, _suma); + _mm_store_ps(outptr + 4 * 11, _sumb); + outptr += 4 * 12; } for (; jj + 7 < max_jj; jj += 8) { @@ -1454,14 +1259,14 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); - _sum3 = _mm_load_ps(ptmp + 12); - _sum4 = _mm_load_ps(ptmp + 16); - _sum5 = _mm_load_ps(ptmp + 20); - _sum6 = _mm_load_ps(ptmp + 24); - _sum7 = _mm_load_ps(ptmp + 28); + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); + _sum4 = _mm_load_ps(outptr + 16); + _sum5 = _mm_load_ps(outptr + 20); + _sum6 = _mm_load_ps(outptr + 24); + _sum7 = _mm_load_ps(outptr + 28); } int kk = 0; @@ -1481,31 +1286,15 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 8; } - if (k_end) - { - _mm_store_ps(outptr, _sum0); - _mm_store_ps(outptr + TILE_M * 1, _sum1); - _mm_store_ps(outptr + TILE_M * 2, _sum2); - _mm_store_ps(outptr + TILE_M * 3, _sum3); - _mm_store_ps(outptr + TILE_M * 4, _sum4); - _mm_store_ps(outptr + TILE_M * 5, _sum5); - _mm_store_ps(outptr + TILE_M * 6, _sum6); - _mm_store_ps(outptr + TILE_M * 7, _sum7); - outptr += TILE_M * 8; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - _mm_store_ps(ptmp + 12, _sum3); - _mm_store_ps(ptmp + 16, _sum4); - _mm_store_ps(ptmp + 20, _sum5); - _mm_store_ps(ptmp + 24, _sum6); - _mm_store_ps(ptmp + 28, _sum7); - } - - ptmp += 32; + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + _mm_store_ps(outptr + 4 * 4, _sum4); + _mm_store_ps(outptr + 4 * 5, _sum5); + _mm_store_ps(outptr + 4 * 6, _sum6); + _mm_store_ps(outptr + 4 * 7, _sum7); + outptr += 4 * 8; } for (; jj + 3 < max_jj; jj += 4) { @@ -1525,10 +1314,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); - _sum3 = _mm_load_ps(ptmp + 12); + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); } int kk = 0; @@ -1544,23 +1333,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 4; } - if (k_end) - { - _mm_store_ps(outptr, _sum0); - _mm_store_ps(outptr + TILE_M, _sum1); - _mm_store_ps(outptr + TILE_M * 2, _sum2); - _mm_store_ps(outptr + TILE_M * 3, _sum3); - outptr += TILE_M * 4; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - _mm_store_ps(ptmp + 12, _sum3); - } - - ptmp += 16; + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + outptr += 4 * 4; } for (; jj + 1 < max_jj; jj += 2) { @@ -1576,8 +1353,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); } int kk = 0; @@ -1591,19 +1368,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 2; } - if (k_end) - { - _mm_store_ps(outptr, _sum0); - _mm_store_ps(outptr + TILE_M, _sum1); - outptr += TILE_M * 2; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - } - - ptmp += 8; + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + outptr += 4 * 2; } for (; jj < max_jj; jj++) { @@ -1617,7 +1384,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum = _mm_load_ps(ptmp); + _sum = _mm_load_ps(outptr); } int kk = 0; @@ -1630,27 +1397,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 1; } - if (k_end) - { - _mm_store_ps(outptr, _sum); - outptr += TILE_M; - } - else - { - _mm_store_ps(ptmp, _sum); - } - - ptmp += 4; + _mm_store_ps(outptr, _sum); + outptr += 4; } - - pAT += max_kk * 4; } + } #endif // __SSE2__ - for (; ii + 1 < max_ii; ii += 2) + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) { - const float* pB = pBT; - - float* outptr = (float*)top_blob.depth(b) + ii; + const float* pAT = AT_tile.row(b) + max_kk * ii; + const float* pB = BT_tile.row(b); int jj = 0; #if __SSE2__ @@ -1676,12 +1434,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); - _sum3 = _mm_load_ps(ptmp + 12); - _sum4 = _mm_load_ps(ptmp + 16); - _sum5 = _mm_load_ps(ptmp + 20); + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + __m128 _tmp2 = _mm_loadu_ps(outptr + 8); + __m128 _tmp3 = _mm_loadu_ps(outptr + 12); + __m128 _tmp4 = _mm_loadu_ps(outptr + 16); + __m128 _tmp5 = _mm_loadu_ps(outptr + 20); + _sum0 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum2 = _mm_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum5 = _mm_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); } int kk = 0; @@ -1702,52 +1466,19 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 12; } - if (k_end) - { - float sum[24]; - _mm_storeu_ps(sum, _sum0); - _mm_storeu_ps(sum + 4, _sum1); - _mm_storeu_ps(sum + 8, _sum2); - _mm_storeu_ps(sum + 12, _sum3); - _mm_storeu_ps(sum + 16, _sum4); - _mm_storeu_ps(sum + 20, _sum5); - outptr[0] = sum[0]; - outptr[TILE_M] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr[TILE_M * 4] = sum[4]; - outptr[TILE_M * 5] = sum[5]; - outptr[TILE_M * 6] = sum[6]; - outptr[TILE_M * 7] = sum[7]; - outptr[TILE_M * 8] = sum[8]; - outptr[TILE_M * 9] = sum[9]; - outptr[TILE_M * 10] = sum[10]; - outptr[TILE_M * 11] = sum[11]; - outptr[1] = sum[12]; - outptr[TILE_M + 1] = sum[13]; - outptr[TILE_M * 2 + 1] = sum[14]; - outptr[TILE_M * 3 + 1] = sum[15]; - outptr[TILE_M * 4 + 1] = sum[16]; - outptr[TILE_M * 5 + 1] = sum[17]; - outptr[TILE_M * 6 + 1] = sum[18]; - outptr[TILE_M * 7 + 1] = sum[19]; - outptr[TILE_M * 8 + 1] = sum[20]; - outptr[TILE_M * 9 + 1] = sum[21]; - outptr[TILE_M * 10 + 1] = sum[22]; - outptr[TILE_M * 11 + 1] = sum[23]; - outptr += TILE_M * 12; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - _mm_store_ps(ptmp + 12, _sum3); - _mm_store_ps(ptmp + 16, _sum4); - _mm_store_ps(ptmp + 20, _sum5); - } - - ptmp += 24; + __m128 _tmp0 = _mm_unpacklo_ps(_sum0, _sum3); + __m128 _tmp1 = _mm_unpackhi_ps(_sum0, _sum3); + __m128 _tmp2 = _mm_unpacklo_ps(_sum1, _sum4); + __m128 _tmp3 = _mm_unpackhi_ps(_sum1, _sum4); + __m128 _tmp4 = _mm_unpacklo_ps(_sum2, _sum5); + __m128 _tmp5 = _mm_unpackhi_ps(_sum2, _sum5); + _mm_storeu_ps(outptr, _tmp0); + _mm_storeu_ps(outptr + 4, _tmp1); + _mm_storeu_ps(outptr + 8, _tmp2); + _mm_storeu_ps(outptr + 12, _tmp3); + _mm_storeu_ps(outptr + 16, _tmp4); + _mm_storeu_ps(outptr + 20, _tmp5); + outptr += 2 * 12; } for (; jj + 7 < max_jj; jj += 8) { @@ -1767,10 +1498,14 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); - _sum3 = _mm_load_ps(ptmp + 12); + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + __m128 _tmp2 = _mm_loadu_ps(outptr + 8); + __m128 _tmp3 = _mm_loadu_ps(outptr + 12); + _sum0 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum2 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum3 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); } int kk = 0; @@ -1788,40 +1523,15 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 8; } - if (k_end) - { - float sum[16]; - _mm_storeu_ps(sum, _sum0); - _mm_storeu_ps(sum + 4, _sum1); - _mm_storeu_ps(sum + 8, _sum2); - _mm_storeu_ps(sum + 12, _sum3); - outptr[0] = sum[0]; - outptr[TILE_M] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr[TILE_M * 4] = sum[4]; - outptr[TILE_M * 5] = sum[5]; - outptr[TILE_M * 6] = sum[6]; - outptr[TILE_M * 7] = sum[7]; - outptr[1] = sum[8]; - outptr[TILE_M + 1] = sum[9]; - outptr[TILE_M * 2 + 1] = sum[10]; - outptr[TILE_M * 3 + 1] = sum[11]; - outptr[TILE_M * 4 + 1] = sum[12]; - outptr[TILE_M * 5 + 1] = sum[13]; - outptr[TILE_M * 6 + 1] = sum[14]; - outptr[TILE_M * 7 + 1] = sum[15]; - outptr += TILE_M * 8; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - _mm_store_ps(ptmp + 12, _sum3); - } - - ptmp += 16; + __m128 _tmp0 = _mm_unpacklo_ps(_sum0, _sum2); + __m128 _tmp1 = _mm_unpackhi_ps(_sum0, _sum2); + __m128 _tmp2 = _mm_unpacklo_ps(_sum1, _sum3); + __m128 _tmp3 = _mm_unpackhi_ps(_sum1, _sum3); + _mm_storeu_ps(outptr, _tmp0); + _mm_storeu_ps(outptr + 4, _tmp1); + _mm_storeu_ps(outptr + 8, _tmp2); + _mm_storeu_ps(outptr + 12, _tmp3); + outptr += 2 * 8; } for (; jj + 3 < max_jj; jj += 4) { @@ -1837,8 +1547,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + _sum0 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); } int kk = 0; @@ -1851,28 +1563,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 4; } - if (k_end) - { - float sum[8]; - _mm_storeu_ps(sum, _sum0); - _mm_storeu_ps(sum + 4, _sum1); - outptr[0] = sum[0]; - outptr[TILE_M * 1] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr[1] = sum[4]; - outptr[TILE_M * 1 + 1] = sum[5]; - outptr[TILE_M * 2 + 1] = sum[6]; - outptr[TILE_M * 3 + 1] = sum[7]; - outptr += TILE_M * 4; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - } - - ptmp += 8; + __m128 _tmp0 = _mm_unpacklo_ps(_sum0, _sum1); + __m128 _tmp1 = _mm_unpackhi_ps(_sum0, _sum1); + _mm_storeu_ps(outptr, _tmp0); + _mm_storeu_ps(outptr + 4, _tmp1); + outptr += 2 * 4; } #endif // __SSE2__ for (; jj + 1 < max_jj; jj += 2) @@ -1893,10 +1588,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - sum00 = ptmp[0]; - sum01 = ptmp[1]; - sum10 = ptmp[2]; - sum11 = ptmp[3]; + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; } int kk = 0; @@ -1910,23 +1605,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 2; } - if (k_end) - { - outptr[0] = sum00; - outptr[1] = sum01; - outptr[TILE_M] = sum10; - outptr[TILE_M + 1] = sum11; - outptr += TILE_M * 2; - } - else - { - ptmp[0] = sum00; - ptmp[1] = sum01; - ptmp[2] = sum10; - ptmp[3] = sum11; - } - - ptmp += 4; + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; } for (; jj < max_jj; jj++) { @@ -1942,8 +1625,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - sum0 = ptmp[0]; - sum1 = ptmp[1]; + sum0 = outptr[0]; + sum1 = outptr[1]; } int kk = 0; @@ -1955,28 +1638,18 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 1; } - if (k_end) - { - outptr[0] = sum0; - outptr[1] = sum1; - outptr += TILE_M; - } - else - { - ptmp[0] = sum0; - ptmp[1] = sum1; - } - - ptmp += 2; + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; } - - pAT += max_kk * 2; } - for (; ii < max_ii; ii++) + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) { - const float* pB = pBT; - - float* outptr = (float*)top_blob.depth(b) + ii; + const float* pAT = AT_tile.row(b) + max_kk * ii; + const float* pB = BT_tile.row(b); int jj = 0; #if __SSE2__ @@ -1996,9 +1669,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); - _sum2 = _mm_load_ps(ptmp + 8); + _sum0 = _mm_loadu_ps(outptr); + _sum1 = _mm_loadu_ps(outptr + 4); + _sum2 = _mm_loadu_ps(outptr + 8); } int kk = 0; @@ -2015,34 +1688,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 12; } - if (k_end) - { - float sum[12]; - _mm_storeu_ps(sum, _sum0); - _mm_storeu_ps(sum + 4, _sum1); - _mm_storeu_ps(sum + 8, _sum2); - outptr[0] = sum[0]; - outptr[TILE_M] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr[TILE_M * 4] = sum[4]; - outptr[TILE_M * 5] = sum[5]; - outptr[TILE_M * 6] = sum[6]; - outptr[TILE_M * 7] = sum[7]; - outptr[TILE_M * 8] = sum[8]; - outptr[TILE_M * 9] = sum[9]; - outptr[TILE_M * 10] = sum[10]; - outptr[TILE_M * 11] = sum[11]; - outptr += TILE_M * 12; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - _mm_store_ps(ptmp + 8, _sum2); - } - - ptmp += 12; + _mm_storeu_ps(outptr, _sum0); + _mm_storeu_ps(outptr + 4, _sum1); + _mm_storeu_ps(outptr + 8, _sum2); + outptr += 12; } for (; jj + 7 < max_jj; jj += 8) { @@ -2058,8 +1707,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum0 = _mm_load_ps(ptmp); - _sum1 = _mm_load_ps(ptmp + 4); + _sum0 = _mm_loadu_ps(outptr); + _sum1 = _mm_loadu_ps(outptr + 4); } int kk = 0; @@ -2074,28 +1723,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 8; } - if (k_end) - { - float sum[8]; - _mm_storeu_ps(sum, _sum0); - _mm_storeu_ps(sum + 4, _sum1); - outptr[0] = sum[0]; - outptr[TILE_M] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr[TILE_M * 4] = sum[4]; - outptr[TILE_M * 5] = sum[5]; - outptr[TILE_M * 6] = sum[6]; - outptr[TILE_M * 7] = sum[7]; - outptr += TILE_M * 8; - } - else - { - _mm_store_ps(ptmp, _sum0); - _mm_store_ps(ptmp + 4, _sum1); - } - - ptmp += 8; + _mm_storeu_ps(outptr, _sum0); + _mm_storeu_ps(outptr + 4, _sum1); + outptr += 8; } for (; jj + 3 < max_jj; jj += 4) { @@ -2109,7 +1739,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - _sum = _mm_load_ps(ptmp); + _sum = _mm_loadu_ps(outptr); } int kk = 0; @@ -2122,22 +1752,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 4; } - if (k_end) - { - float sum[4]; - _mm_storeu_ps(sum, _sum); - outptr[0] = sum[0]; - outptr[TILE_M] = sum[1]; - outptr[TILE_M * 2] = sum[2]; - outptr[TILE_M * 3] = sum[3]; - outptr += TILE_M * 4; - } - else - { - _mm_store_ps(ptmp, _sum); - } - - ptmp += 4; + _mm_storeu_ps(outptr, _sum); + outptr += 4; } #endif // __SSE2__ for (; jj + 1 < max_jj; jj += 2) @@ -2154,8 +1770,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - sum0 = ptmp[0]; - sum1 = ptmp[1]; + sum0 = outptr[0]; + sum1 = outptr[1]; } int kk = 0; @@ -2167,19 +1783,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 2; } - if (k_end) - { - outptr[0] = sum0; - outptr[TILE_M] = sum1; - outptr += TILE_M * 2; - } - else - { - ptmp[0] = sum0; - ptmp[1] = sum1; - } - - ptmp += 2; + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; } for (; jj < max_jj; jj++) { @@ -2193,7 +1799,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& } else { - sum = ptmp[0]; + sum = outptr[0]; } int kk = 0; @@ -2204,20 +1810,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, Mat& pB += 1; } - if (k_end) - { - outptr[0] = sum; - outptr += TILE_M; - } - else - { - ptmp[0] = sum; - } - - ptmp += 1; + outptr[0] = sum; + outptr += 1; } - - pAT += max_kk; } } } @@ -2226,99 +1821,97 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, { // resolve optimal tile size from cache size size_t l2_cache_size = get_cpu_level2_cache_size(); - int tile_size = (int)sqrt((float)l2_cache_size / 3 / sizeof(float)); + + // solve M + { + int tile_size = (int)sqrt((float)l2_cache_size / sizeof(float) / 3); #if __AVX512F__ - TILE_M = tile_size / 16 * 16; - TILE_N = tile_size / 4 * 4; - TILE_K = tile_size / 16 * 16; + TILE_M = tile_size / 16 * 16; #elif __AVX__ - TILE_M = tile_size / 8 * 8; - TILE_N = tile_size / 4 * 4; - TILE_K = tile_size / 8 * 8; + TILE_M = tile_size / 8 * 8; #elif __SSE2__ - TILE_M = tile_size / 4 * 4; - TILE_N = tile_size / 4 * 4; - TILE_K = tile_size / 4 * 4; + TILE_M = tile_size / 4 * 4; #else - TILE_M = tile_size / 2 * 2; - TILE_N = tile_size; - TILE_K = tile_size / 2 * 2; + TILE_M = tile_size / 2 * 2; #endif - if (K > 0) - { - int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; #if __AVX512F__ - TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); #elif __AVX__ - TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); #elif __SSE2__ - TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); #else - TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); #endif - if (nn_K == 1) + if (nT > 1) { - tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); - #if __AVX512F__ - TILE_M = tile_size / 16 * 16; - TILE_N = tile_size / 4 * 4; + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); #elif __AVX__ - TILE_M = tile_size / 8 * 8; - TILE_N = tile_size / 4 * 4; + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); #elif __SSE2__ - TILE_M = tile_size / 4 * 4; - TILE_N = tile_size / 4 * 4; + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); #else - TILE_M = tile_size / 2 * 2; - TILE_N = tile_size; + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); #endif } } - TILE_M *= std::min(nT, get_physical_cpu_count()); - - if (M > 0) + // solve K { - int nn_M = (M + TILE_M - 1) / TILE_M; + int tile_size = (int)(sqrt((float)l2_cache_size / sizeof(float)) - TILE_M); + #if __AVX512F__ - TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); + TILE_K = tile_size / 16 * 16; #elif __AVX__ - TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + TILE_K = tile_size / 8 * 8; #elif __SSE2__ - TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); + TILE_K = tile_size / 4 * 4; #else - TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); + TILE_K = tile_size / 2 * 2; +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); #endif } if (N > 0) { - int nn_N = (N + TILE_N - 1) / TILE_N; + int tile_size = (int)(((float)l2_cache_size / sizeof(float) - TILE_M * TILE_K) / (TILE_M + TILE_K)); + #if __AVX512F__ - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + TILE_N = tile_size / 4 * 4; #elif __AVX__ - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + TILE_N = tile_size / 4 * 4; #elif __SSE2__ - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + TILE_N = tile_size / 4 * 4; #else - TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); + TILE_N = tile_size; #endif - } - if (nT > 1) - { + int nn_N = (N + TILE_N - 1) / TILE_N; #if __AVX512F__ - TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); #elif __AVX__ - TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); #elif __SSE2__ - TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); #else - TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); #endif } } @@ -2415,7 +2008,7 @@ static void conv3x3s1_winograd23_transform_kernel(const Mat& kernel, Mat& AT, in } } -static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) { // const float itm[4][4] = { // {1.0f, 0.0f, -1.0f, 0.0f}, @@ -2431,27 +2024,30 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b const int w_tiles = (w - 1) / 2; - float* ptmp = B; - - int kk = 0; + int nn_max_kk = 0; + int remain_max_kk_start = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - for (; kk + 15 < max_kk; kk += 16) + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[4][4][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[4][4][16]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; for (int m = 0; m < 4; m++) @@ -2549,6 +2145,12 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 16 + jj * 16; + float* p1 = p0 + max_jj * 16; + float* p2 = p0 + max_jj * 16 * 2; + float* p3 = p0 + max_jj * 16 * 3; + for (int m = 0; m < 4; m++) { __m512 _r0 = _mm512_load_ps(tmp[m][0]); @@ -2561,30 +2163,41 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b __m512 _tmp2 = _mm512_sub_ps(_r2, _r1); __m512 _tmp3 = _mm512_sub_ps(_r3, _r1); - _mm512_store_ps(ptmp, _tmp0); - _mm512_store_ps(ptmp + 16, _tmp1); - _mm512_store_ps(ptmp + 32, _tmp2); - _mm512_store_ps(ptmp + 48, _tmp3); - ptmp += 64; + _mm512_store_ps(p0, _tmp0); + _mm512_store_ps(p1, _tmp1); + _mm512_store_ps(p2, _tmp2); + _mm512_store_ps(p3, _tmp3); + + p0 += max_jj * 4 * 16; + p1 += max_jj * 4 * 16; + p2 += max_jj * 4 * 16; + p3 += max_jj * 4 * 16; } } } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) #endif // __AVX512F__ - for (; kk + 7 < max_kk; kk += 8) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[4][4][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[4][4][8]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; for (int m = 0; m < 4; m++) @@ -2662,6 +2275,12 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 16 + jj * 8; + float* p1 = p0 + max_jj * 8; + float* p2 = p0 + max_jj * 8 * 2; + float* p3 = p0 + max_jj * 8 * 3; + for (int m = 0; m < 4; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -2681,30 +2300,41 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b __m256 _tmp2 = _mm256_sub_ps(_r2, _r1); __m256 _tmp3 = _mm256_sub_ps(_r3, _r1); - _mm256_store_ps(ptmp, _tmp0); - _mm256_store_ps(ptmp + 8, _tmp1); - _mm256_store_ps(ptmp + 16, _tmp2); - _mm256_store_ps(ptmp + 24, _tmp3); - ptmp += 32; + _mm256_store_ps(p0, _tmp0); + _mm256_store_ps(p1, _tmp1); + _mm256_store_ps(p2, _tmp2); + _mm256_store_ps(p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; } } } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 4; +#else // __AVX__ + nn_max_kk = (max_kk - remain_max_kk_start) / 4; + #pragma omp parallel for num_threads(nT) #endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 4; + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[4][4][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[4][4][4]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; for (int m = 0; m < 4; m++) @@ -2762,6 +2392,12 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 16 + jj * 4; + float* p1 = p0 + max_jj * 4; + float* p2 = p0 + max_jj * 4 * 2; + float* p3 = p0 + max_jj * 4 * 3; + for (int m = 0; m < 4; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -2781,25 +2417,36 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b __m128 _tmp2 = _mm_sub_ps(_r2, _r1); __m128 _tmp3 = _mm_sub_ps(_r3, _r1); - _mm_store_ps(ptmp, _tmp0); - _mm_store_ps(ptmp + 4, _tmp1); - _mm_store_ps(ptmp + 8, _tmp2); - _mm_store_ps(ptmp + 12, _tmp3); - ptmp += 16; + _mm_store_ps(p0, _tmp0); + _mm_store_ps(p1, _tmp1); + _mm_store_ps(p2, _tmp2); + _mm_store_ps(p3, _tmp3); + + p0 += max_jj * 4 * 4; + p1 += max_jj * 4 * 4; + p2 += max_jj * 4 * 4; + p3 += max_jj * 4 * 4; } } } + remain_max_kk_start += nn_max_kk * 4; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) #endif // __SSE2__ - for (; kk + 1 < max_kk; kk += 2) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 2; + + float tmp[4][4][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[4][4][2]; - const float* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); for (int m = 0; m < 4; m++) @@ -2850,6 +2497,12 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b r0 += w; } + + float* p0 = (float*)B + kk * max_jj * 16 + jj * 2; + float* p1 = p0 + max_jj * 2; + float* p2 = p0 + max_jj * 2 * 2; + float* p3 = p0 + max_jj * 2 * 3; + for (int m = 0; m < 4; m++) { float r00 = tmp[m][0][0]; @@ -2861,28 +2514,33 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b float r30 = tmp[m][3][0]; float r31 = tmp[m][3][1]; - ptmp[0] = r00 - r20; - ptmp[1] = r01 - r21; - ptmp[2] = r10 + r20; - ptmp[3] = r11 + r21; - ptmp[4] = r20 - r10; - ptmp[5] = r21 - r11; - ptmp[6] = r30 - r10; - ptmp[7] = r31 - r11; - ptmp += 8; + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; } } } - for (; kk < max_kk; kk++) + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) { + float tmp[4][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[4][4]; - const float* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); for (int m = 0; m < 4; m++) @@ -2910,6 +2568,12 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b r0123 += w; } + + float* p0 = (float*)B + kk * max_jj * 16 + jj; + float* p1 = p0 + max_jj; + float* p2 = p0 + max_jj * 2; + float* p3 = p0 + max_jj * 3; + for (int m = 0; m < 4; m++) { float r0 = tmp[m][0]; @@ -2917,11 +2581,15 @@ static inline void conv3x3s1_winograd23_transform_input_tile(const Mat& bottom_b float r2 = tmp[m][2]; float r3 = tmp[m][3]; - ptmp[0] = r0 - r2; - ptmp[1] = r1 + r2; - ptmp[2] = r2 - r1; - ptmp[3] = r3 - r1; - ptmp += 4; + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; } } } @@ -2951,28 +2619,26 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til { __m512 _bias0 = biasptr ? _mm512_loadu_ps(biasptr + i + ii) : _mm512_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[2][4][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[2][4][16]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 16 + jj * 16; + const float* r1 = r0 + max_jj * 16; + const float* r2 = r0 + max_jj * 16 * 2; + const float* r3 = r0 + max_jj * 16 * 3; for (int m = 0; m < 4; m++) { - const float* r0 = top_tile.depth(m * 4).row(jj) + ii; - const float* r1 = top_tile.depth(m * 4 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 4 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 4 + 3).row(jj) + ii; - __m512 _r0 = _mm512_load_ps(r0); __m512 _r1 = _mm512_load_ps(r1); __m512 _r2 = _mm512_load_ps(r2); @@ -2983,7 +2649,15 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til _mm512_store_ps(tmp[0][m], _tmp0); _mm512_store_ps(tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 16; + r1 += max_jj * 4 * 16; + r2 += max_jj * 4 * 16; + r3 += max_jj * 4 * 16; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + for (int m = 0; m < 2; m++) { if (ti * 2 + m >= outh) @@ -3105,28 +2779,26 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til { __m256 _bias0 = biasptr ? _mm256_loadu_ps(biasptr + i + ii) : _mm256_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[2][4][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[2][4][8]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 16 + jj * 8; + const float* r1 = r0 + max_jj * 8; + const float* r2 = r0 + max_jj * 8 * 2; + const float* r3 = r0 + max_jj * 8 * 3; for (int m = 0; m < 4; m++) { - const float* r0 = top_tile.depth(m * 4).row(jj) + ii; - const float* r1 = top_tile.depth(m * 4 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 4 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 4 + 3).row(jj) + ii; - __m256 _r0 = _mm256_load_ps(r0); __m256 _r1 = _mm256_load_ps(r1); __m256 _r2 = _mm256_load_ps(r2); @@ -3142,7 +2814,15 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til _mm256_store_ps(tmp[0][m], _tmp0); _mm256_store_ps(tmp[1][m], _tmp1); #endif + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + for (int m = 0; m < 2; m++) { if (ti * 2 + m >= outh) @@ -3229,28 +2909,26 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til { __m128 _bias0 = biasptr ? _mm_loadu_ps(biasptr + i + ii) : _mm_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[2][4][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[2][4][4]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 16 + jj * 4; + const float* r1 = r0 + max_jj * 4; + const float* r2 = r0 + max_jj * 4 * 2; + const float* r3 = r0 + max_jj * 4 * 3; for (int m = 0; m < 4; m++) { - const float* r0 = top_tile.depth(m * 4).row(jj) + ii; - const float* r1 = top_tile.depth(m * 4 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 4 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 4 + 3).row(jj) + ii; - __m128 _r0 = _mm_load_ps(r0); __m128 _r1 = _mm_load_ps(r1); __m128 _r2 = _mm_load_ps(r2); @@ -3266,7 +2944,15 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til _mm_store_ps(tmp[0][m], _tmp0); _mm_store_ps(tmp[1][m], _tmp1); #endif + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + for (int m = 0; m < 2; m++) { if (ti * 2 + m >= outh) @@ -3327,32 +3013,34 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til float bias0 = biasptr ? biasptr[i + ii] : 0.f; float bias1 = biasptr ? biasptr[i + ii + 1] : 0.f; + float tmp[2][4][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[2][4][2]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + const float* r0 = (const float*)top_tile + ii * max_jj * 16 + jj * 2; + const float* r1 = r0 + max_jj * 2; + const float* r2 = r0 + max_jj * 2 * 2; + const float* r3 = r0 + max_jj * 2 * 3; for (int m = 0; m < 4; m++) { - float r00 = top_tile.depth(m * 4).row(jj)[ii]; - float r01 = top_tile.depth(m * 4).row(jj)[ii + 1]; - float r10 = top_tile.depth(m * 4 + 1).row(jj)[ii]; - float r11 = top_tile.depth(m * 4 + 1).row(jj)[ii + 1]; - float r20 = top_tile.depth(m * 4 + 2).row(jj)[ii]; - float r21 = top_tile.depth(m * 4 + 2).row(jj)[ii + 1]; - float r30 = top_tile.depth(m * 4 + 3).row(jj)[ii]; - float r31 = top_tile.depth(m * 4 + 3).row(jj)[ii + 1]; - - tmp[0][m][0] = r00 + r10 + r20; - tmp[0][m][1] = r01 + r11 + r21; - tmp[1][m][0] = r10 - r20 + r30; - tmp[1][m][1] = r11 - r21 + r31; + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + for (int m = 0; m < 2; m++) { if (ti * 2 + m >= outh) @@ -3393,26 +3081,32 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til { float bias0 = biasptr ? biasptr[i + ii] : 0.f; + float tmp[2][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[2][4]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + const float* r0 = (const float*)top_tile + ii * max_jj * 16 + jj; + const float* r1 = r0 + max_jj; + const float* r2 = r0 + max_jj * 2; + const float* r3 = r0 + max_jj * 3; for (int m = 0; m < 4; m++) { - float r0 = top_tile.depth(m * 4).row(jj)[ii]; - float r1 = top_tile.depth(m * 4 + 1).row(jj)[ii]; - float r2 = top_tile.depth(m * 4 + 2).row(jj)[ii]; - float r3 = top_tile.depth(m * 4 + 3).row(jj)[ii]; + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; - tmp[0][m] = r0 + r1 + r2; - tmp[1][m] = r1 - r2 + r3; + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + for (int m = 0; m < 2; m++) { if (ti * 2 + m >= outh) @@ -3438,7 +3132,7 @@ static inline void conv3x3s1_winograd23_transform_output_tile(const Mat& top_til } } -static void conv3x3s1_winograd23(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, const Option& opt) +static void conv3x3s1_winograd23(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int nT, const Option& opt) { int outw = top_blob.w; int outh = top_blob.h; @@ -3455,58 +3149,76 @@ static void conv3x3s1_winograd23(const Mat& bottom_blob, Mat& top_blob, const Ma // NCNN_LOGE("conv3x3s1_winograd23 %d %d %d", M, N, K); - int nT = opt.num_threads; - int TILE_M, TILE_N, TILE_K; get_optimal_tile_mnk(M, N, K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); - Mat B_tileX(B * TILE_N * TILE_K, 1, nT, 4u, opt.blob_allocator); Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); - #pragma omp parallel for num_threads(nT) - for (int ppj = 0; ppj < nn_N; ppj++) + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) { - const int j = ppj * TILE_N; + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.blob_allocator); - Mat B_tile = B_tileX.channel(get_omp_thread_num()); + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; - const int max_jj = std::min((N - j), TILE_N); + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { + const int max_jj = std::min((N - j), TILE_N); const int max_kk = std::min((K - k), TILE_K); // transform input - conv3x3s1_winograd23_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk); + conv3x3s1_winograd23_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk); + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, nT); } } - - Mat tmpX; - if (TILE_K < K) + else { - tmpX.create(TILE_M * TILE_N, B, nT, 4u, opt.blob_allocator); + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.blob_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, 1); + } } - Mat top_tileX(TILE_M, TILE_N, B, nT, 4u, opt.blob_allocator); + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.blob_allocator); #pragma omp parallel for num_threads(nT) for (int ppj = 0; ppj < nn_M; ppj++) { const int i = ppj * TILE_M; - Mat tmp; - if (K > TILE_K) - tmp = tmpX.channel(get_omp_thread_num()); - Mat top_tile = top_tileX.channel(get_omp_thread_num()); const int max_ii = std::min((M - i), TILE_M); @@ -3523,9 +3235,7 @@ static void conv3x3s1_winograd23(const Mat& bottom_blob, Mat& top_blob, const Ma const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - bool k_end = k + TILE_K >= K; - - gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, tmp, B, max_ii, max_jj, k, max_kk, k_end); + gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); } // transform output @@ -3638,7 +3348,7 @@ static void conv3x3s1_winograd43_transform_kernel(const Mat& kernel, Mat& AT, in } } -static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) { // const float itm[6][6] = { // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, @@ -3656,27 +3366,30 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b const int w_tiles = (w + 1) / 4; - float* ptmp = B; - - int kk = 0; + int nn_max_kk = 0; + int remain_max_kk_start = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - for (; kk + 15 < max_kk; kk += 16) + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[6][6][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[6][6][16]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; __m512 _vm5 = _mm512_set1_ps(-5.f); @@ -3794,6 +3507,14 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 36 + jj * 16; + float* p1 = p0 + max_jj * 16; + float* p2 = p0 + max_jj * 16 * 2; + float* p3 = p0 + max_jj * 16 * 3; + float* p4 = p0 + max_jj * 16 * 4; + float* p5 = p0 + max_jj * 16 * 5; + for (int m = 0; m < 6; m++) { __m512 _r0 = _mm512_load_ps(tmp[m][0]); @@ -3810,32 +3531,45 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b __m512 _tmp4 = _mm512_fmadd_ps(_v2, _mm512_sub_ps(_r1, _r3), _mm512_sub_ps(_r4, _r2)); __m512 _tmp5 = _mm512_fmadd_ps(_vm5, _r3, _mm512_fmadd_ps(_v4, _r1, _r5)); - _mm512_store_ps(ptmp, _tmp0); - _mm512_store_ps(ptmp + 16, _tmp1); - _mm512_store_ps(ptmp + 32, _tmp2); - _mm512_store_ps(ptmp + 48, _tmp3); - _mm512_store_ps(ptmp + 64, _tmp4); - _mm512_store_ps(ptmp + 80, _tmp5); - ptmp += 96; + _mm512_store_ps(p0, _tmp0); + _mm512_store_ps(p1, _tmp1); + _mm512_store_ps(p2, _tmp2); + _mm512_store_ps(p3, _tmp3); + _mm512_store_ps(p4, _tmp4); + _mm512_store_ps(p5, _tmp5); + + p0 += max_jj * 6 * 16; + p1 += max_jj * 6 * 16; + p2 += max_jj * 6 * 16; + p3 += max_jj * 6 * 16; + p4 += max_jj * 6 * 16; + p5 += max_jj * 6 * 16; } } } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) #endif // __AVX512F__ - for (; kk + 7 < max_kk; kk += 8) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[6][6][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[6][6][8]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; __m256 _vm5 = _mm256_set1_ps(-5.f); @@ -3931,6 +3665,14 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 36 + jj * 8; + float* p1 = p0 + max_jj * 8; + float* p2 = p0 + max_jj * 8 * 2; + float* p3 = p0 + max_jj * 8 * 3; + float* p4 = p0 + max_jj * 8 * 4; + float* p5 = p0 + max_jj * 8 * 5; + for (int m = 0; m < 6; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -3956,32 +3698,45 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b __m256 _tmp4 = _mm256_comp_fmadd_ps(_v2, _mm256_sub_ps(_r1, _r3), _mm256_sub_ps(_r4, _r2)); __m256 _tmp5 = _mm256_comp_fmadd_ps(_vm5, _r3, _mm256_comp_fmadd_ps(_v4, _r1, _r5)); - _mm256_store_ps(ptmp, _tmp0); - _mm256_store_ps(ptmp + 8, _tmp1); - _mm256_store_ps(ptmp + 16, _tmp2); - _mm256_store_ps(ptmp + 24, _tmp3); - _mm256_store_ps(ptmp + 32, _tmp4); - _mm256_store_ps(ptmp + 40, _tmp5); - ptmp += 48; + _mm256_store_ps(p0, _tmp0); + _mm256_store_ps(p1, _tmp1); + _mm256_store_ps(p2, _tmp2); + _mm256_store_ps(p3, _tmp3); + _mm256_store_ps(p4, _tmp4); + _mm256_store_ps(p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; } } } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 4; +#else // __AVX__ + nn_max_kk = (max_kk - remain_max_kk_start) / 4; + #pragma omp parallel for num_threads(nT) #endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 4; + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[6][6][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[6][6][4]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; __m128 _vm5 = _mm_set1_ps(-5.f); @@ -4057,6 +3812,14 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 36 + jj * 4; + float* p1 = p0 + max_jj * 4; + float* p2 = p0 + max_jj * 4 * 2; + float* p3 = p0 + max_jj * 4 * 3; + float* p4 = p0 + max_jj * 4 * 4; + float* p5 = p0 + max_jj * 4 * 5; + for (int m = 0; m < 6; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -4082,27 +3845,40 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b __m128 _tmp4 = _mm_comp_fmadd_ps(_v2, _mm_sub_ps(_r1, _r3), _mm_sub_ps(_r4, _r2)); __m128 _tmp5 = _mm_comp_fmadd_ps(_vm5, _r3, _mm_comp_fmadd_ps(_v4, _r1, _r5)); - _mm_store_ps(ptmp, _tmp0); - _mm_store_ps(ptmp + 4, _tmp1); - _mm_store_ps(ptmp + 8, _tmp2); - _mm_store_ps(ptmp + 12, _tmp3); - _mm_store_ps(ptmp + 16, _tmp4); - _mm_store_ps(ptmp + 20, _tmp5); - ptmp += 24; + _mm_store_ps(p0, _tmp0); + _mm_store_ps(p1, _tmp1); + _mm_store_ps(p2, _tmp2); + _mm_store_ps(p3, _tmp3); + _mm_store_ps(p4, _tmp4); + _mm_store_ps(p5, _tmp5); + + p0 += max_jj * 6 * 4; + p1 += max_jj * 6 * 4; + p2 += max_jj * 6 * 4; + p3 += max_jj * 6 * 4; + p4 += max_jj * 6 * 4; + p5 += max_jj * 6 * 4; } } } + remain_max_kk_start += nn_max_kk * 4; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) #endif // __SSE2__ - for (; kk + 1 < max_kk; kk += 2) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 2; + + float tmp[6][6][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[6][6][2]; - const float* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); for (int m = 0; m < 6; m++) @@ -4171,6 +3947,14 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b r0 += w; } + + float* p0 = (float*)B + kk * max_jj * 36 + jj * 2; + float* p1 = p0 + max_jj * 2; + float* p2 = p0 + max_jj * 2 * 2; + float* p3 = p0 + max_jj * 2 * 3; + float* p4 = p0 + max_jj * 2 * 4; + float* p5 = p0 + max_jj * 2 * 5; + for (int m = 0; m < 6; m++) { float r00 = tmp[m][0][0]; @@ -4186,32 +3970,39 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b float r50 = tmp[m][5][0]; float r51 = tmp[m][5][1]; - ptmp[0] = r00 * 4.f - r20 * 5.f + r40; - ptmp[1] = r01 * 4.f - r21 * 5.f + r41; - ptmp[2] = -r10 * 4.f - r20 * 4.f + r30 + r40; - ptmp[3] = -r11 * 4.f - r21 * 4.f + r31 + r41; - ptmp[4] = r10 * 4.f - r20 * 4.f - r30 + r40; - ptmp[5] = r11 * 4.f - r21 * 4.f - r31 + r41; - ptmp[6] = -r10 * 2.f - r20 + r30 * 2.f + r40; - ptmp[7] = -r11 * 2.f - r21 + r31 * 2.f + r41; - ptmp[8] = r10 * 2.f - r20 - r30 * 2.f + r40; - ptmp[9] = r11 * 2.f - r21 - r31 * 2.f + r41; - ptmp[10] = r10 * 4.f - r30 * 5.f + r50; - ptmp[11] = r11 * 4.f - r31 * 5.f + r51; - ptmp += 12; + p0[0] = r00 * 4.f - r20 * 5.f + r40; + p0[1] = r01 * 4.f - r21 * 5.f + r41; + p1[0] = -r10 * 4.f - r20 * 4.f + r30 + r40; + p1[1] = -r11 * 4.f - r21 * 4.f + r31 + r41; + p2[0] = r10 * 4.f - r20 * 4.f - r30 + r40; + p2[1] = r11 * 4.f - r21 * 4.f - r31 + r41; + p3[0] = -r10 * 2.f - r20 + r30 * 2.f + r40; + p3[1] = -r11 * 2.f - r21 + r31 * 2.f + r41; + p4[0] = r10 * 2.f - r20 - r30 * 2.f + r40; + p4[1] = r11 * 2.f - r21 - r31 * 2.f + r41; + p5[0] = r10 * 4.f - r30 * 5.f + r50; + p5[1] = r11 * 4.f - r31 * 5.f + r51; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; } } } - for (; kk < max_kk; kk++) + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) { + float tmp[6][6]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[6][6]; - const float* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); for (int m = 0; m < 6; m++) @@ -4245,6 +4036,14 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b r0123 += w; } + + float* p0 = (float*)B + kk * max_jj * 36 + jj; + float* p1 = p0 + max_jj; + float* p2 = p0 + max_jj * 2; + float* p3 = p0 + max_jj * 3; + float* p4 = p0 + max_jj * 4; + float* p5 = p0 + max_jj * 5; + for (int m = 0; m < 6; m++) { float r0 = tmp[m][0]; @@ -4254,13 +4053,19 @@ static inline void conv3x3s1_winograd43_transform_input_tile(const Mat& bottom_b float r4 = tmp[m][4]; float r5 = tmp[m][5]; - ptmp[0] = r0 * 4.f - r2 * 5.f + r4; - ptmp[1] = -r1 * 4.f - r2 * 4.f + r3 + r4; - ptmp[2] = r1 * 4.f - r2 * 4.f - r3 + r4; - ptmp[3] = -r1 * 2.f - r2 + r3 * 2.f + r4; - ptmp[4] = r1 * 2.f - r2 - r3 * 2.f + r4; - ptmp[5] = r1 * 4.f - r3 * 5.f + r5; - ptmp += 6; + p0[0] = r0 * 4.f - r2 * 5.f + r4; + p1[0] = -r1 * 4.f - r2 * 4.f + r3 + r4; + p2[0] = r1 * 4.f - r2 * 4.f - r3 + r4; + p3[0] = -r1 * 2.f - r2 + r3 * 2.f + r4; + p4[0] = r1 * 2.f - r2 - r3 * 2.f + r4; + p5[0] = r1 * 4.f - r3 * 5.f + r5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; } } } @@ -4292,20 +4097,25 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til { __m512 _bias0 = biasptr ? _mm512_loadu_ps(biasptr + i + ii) : _mm512_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[4][6][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[4][6][16]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 36 + jj * 16; + const float* r1 = r0 + max_jj * 16; + const float* r2 = r0 + max_jj * 16 * 2; + const float* r3 = r0 + max_jj * 16 * 3; + const float* r4 = r0 + max_jj * 16 * 4; + const float* r5 = r0 + max_jj * 16 * 5; __m512 _v2 = _mm512_set1_ps(2.f); __m512 _v4 = _mm512_set1_ps(4.f); @@ -4313,13 +4123,6 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til for (int m = 0; m < 6; m++) { - const float* r0 = top_tile.depth(m * 6).row(jj) + ii; - const float* r1 = top_tile.depth(m * 6 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 6 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 6 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 6 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 6 + 5).row(jj) + ii; - __m512 _r0 = _mm512_load_ps(r0); __m512 _r1 = _mm512_load_ps(r1); __m512 _r2 = _mm512_load_ps(r2); @@ -4340,7 +4143,17 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til _mm512_store_ps(tmp[1][m], _tmp1); _mm512_store_ps(tmp[2][m], _tmp2); _mm512_store_ps(tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + for (int m = 0; m < 4; m++) { if (ti * 4 + m >= outh) @@ -4534,20 +4347,25 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til { __m256 _bias0 = biasptr ? _mm256_loadu_ps(biasptr + i + ii) : _mm256_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[4][6][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[4][6][8]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 36 + jj * 8; + const float* r1 = r0 + max_jj * 8; + const float* r2 = r0 + max_jj * 8 * 2; + const float* r3 = r0 + max_jj * 8 * 3; + const float* r4 = r0 + max_jj * 8 * 4; + const float* r5 = r0 + max_jj * 8 * 5; __m256 _v2 = _mm256_set1_ps(2.f); __m256 _v4 = _mm256_set1_ps(4.f); @@ -4555,13 +4373,6 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til for (int m = 0; m < 6; m++) { - const float* r0 = top_tile.depth(m * 6).row(jj) + ii; - const float* r1 = top_tile.depth(m * 6 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 6 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 6 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 6 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 6 + 5).row(jj) + ii; - __m256 _r0 = _mm256_load_ps(r0); __m256 _r1 = _mm256_load_ps(r1); __m256 _r2 = _mm256_load_ps(r2); @@ -4589,7 +4400,17 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til _mm256_store_ps(tmp[2][m], _tmp2); _mm256_store_ps(tmp[3][m], _tmp3); #endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + for (int m = 0; m < 4; m++) { if (ti * 4 + m >= outh) @@ -4720,20 +4541,25 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til { __m128 _bias0 = biasptr ? _mm_loadu_ps(biasptr + i + ii) : _mm_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[4][6][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[4][6][4]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 36 + jj * 4; + const float* r1 = r0 + max_jj * 4; + const float* r2 = r0 + max_jj * 4 * 2; + const float* r3 = r0 + max_jj * 4 * 3; + const float* r4 = r0 + max_jj * 4 * 4; + const float* r5 = r0 + max_jj * 4 * 5; __m128 _v2 = _mm_set1_ps(2.f); __m128 _v4 = _mm_set1_ps(4.f); @@ -4741,13 +4567,6 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til for (int m = 0; m < 6; m++) { - const float* r0 = top_tile.depth(m * 6).row(jj) + ii; - const float* r1 = top_tile.depth(m * 6 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 6 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 6 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 6 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 6 + 5).row(jj) + ii; - __m128 _r0 = _mm_load_ps(r0); __m128 _r1 = _mm_load_ps(r1); __m128 _r2 = _mm_load_ps(r2); @@ -4775,7 +4594,17 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til _mm_store_ps(tmp[2][m], _tmp2); _mm_store_ps(tmp[3][m], _tmp3); #endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + for (int m = 0; m < 4; m++) { if (ti * 4 + m >= outh) @@ -4865,40 +4694,42 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til float bias0 = biasptr ? biasptr[i + ii] : 0.f; float bias1 = biasptr ? biasptr[i + ii + 1] : 0.f; + float tmp[4][6][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[4][6][2]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + const float* r0 = (const float*)top_tile + ii * max_jj * 36 + jj * 2; + const float* r1 = r0 + max_jj * 2; + const float* r2 = r0 + max_jj * 2 * 2; + const float* r3 = r0 + max_jj * 2 * 3; + const float* r4 = r0 + max_jj * 2 * 4; + const float* r5 = r0 + max_jj * 2 * 5; for (int m = 0; m < 6; m++) { - float r00 = top_tile.depth(m * 6).row(jj)[ii]; - float r01 = top_tile.depth(m * 6).row(jj)[ii + 1]; - float r10 = top_tile.depth(m * 6 + 1).row(jj)[ii]; - float r11 = top_tile.depth(m * 6 + 1).row(jj)[ii + 1]; - float r20 = top_tile.depth(m * 6 + 2).row(jj)[ii]; - float r21 = top_tile.depth(m * 6 + 2).row(jj)[ii + 1]; - float r30 = top_tile.depth(m * 6 + 3).row(jj)[ii]; - float r31 = top_tile.depth(m * 6 + 3).row(jj)[ii + 1]; - float r40 = top_tile.depth(m * 6 + 4).row(jj)[ii]; - float r41 = top_tile.depth(m * 6 + 4).row(jj)[ii + 1]; - float r50 = top_tile.depth(m * 6 + 5).row(jj)[ii]; - float r51 = top_tile.depth(m * 6 + 5).row(jj)[ii + 1]; - - tmp[0][m][0] = r00 + r10 + r20 + r30 + r40; - tmp[0][m][1] = r01 + r11 + r21 + r31 + r41; - tmp[1][m][0] = r10 - r20 + r30 * 2.f - r40 * 2.f; - tmp[1][m][1] = r11 - r21 + r31 * 2.f - r41 * 2.f; - tmp[2][m][0] = r10 + r20 + r30 * 4.f + r40 * 4.f; - tmp[2][m][1] = r11 + r21 + r31 * 4.f + r41 * 4.f; - tmp[3][m][0] = r10 - r20 + r30 * 8.f - r40 * 8.f + r50; - tmp[3][m][1] = r11 - r21 + r31 * 8.f - r41 * 8.f + r51; + tmp[0][m][0] = r0[0] + r1[0] + r2[0] + r3[0] + r4[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1] + r3[1] + r4[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0] * 2.f - r4[0] * 2.f; + tmp[1][m][1] = r1[1] - r2[1] + r3[1] * 2.f - r4[1] * 2.f; + tmp[2][m][0] = r1[0] + r2[0] + r3[0] * 4.f + r4[0] * 4.f; + tmp[2][m][1] = r1[1] + r2[1] + r3[1] * 4.f + r4[1] * 4.f; + tmp[3][m][0] = r1[0] - r2[0] + r3[0] * 8.f - r4[0] * 8.f + r5[0]; + tmp[3][m][1] = r1[1] - r2[1] + r3[1] * 8.f - r4[1] * 8.f + r5[1]; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + for (int m = 0; m < 4; m++) { if (ti * 4 + m >= outh) @@ -4957,30 +4788,38 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til { float bias0 = biasptr ? biasptr[i + ii] : 0.f; + float tmp[4][6]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[4][6]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + const float* r0 = (const float*)top_tile + ii * max_jj * 36 + jj; + const float* r1 = r0 + max_jj; + const float* r2 = r0 + max_jj * 2; + const float* r3 = r0 + max_jj * 3; + const float* r4 = r0 + max_jj * 4; + const float* r5 = r0 + max_jj * 5; for (int m = 0; m < 6; m++) { - float r0 = top_tile.depth(m * 6).row(jj)[ii]; - float r1 = top_tile.depth(m * 6 + 1).row(jj)[ii]; - float r2 = top_tile.depth(m * 6 + 2).row(jj)[ii]; - float r3 = top_tile.depth(m * 6 + 3).row(jj)[ii]; - float r4 = top_tile.depth(m * 6 + 4).row(jj)[ii]; - float r5 = top_tile.depth(m * 6 + 5).row(jj)[ii]; - - tmp[0][m] = r0 + r1 + r2 + r3 + r4; - tmp[1][m] = r1 - r2 + r3 * 2.f - r4 * 2.f; - tmp[2][m] = r1 + r2 + r3 * 4.f + r4 * 4.f; - tmp[3][m] = r1 - r2 + r3 * 8.f - r4 * 8.f + r5; + tmp[0][m] = r0[0] + r1[0] + r2[0] + r3[0] + r4[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0] * 2.f - r4[0] * 2.f; + tmp[2][m] = r1[0] + r2[0] + r3[0] * 4.f + r4[0] * 4.f; + tmp[3][m] = r1[0] - r2[0] + r3[0] * 8.f - r4[0] * 8.f + r5[0]; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + for (int m = 0; m < 4; m++) { if (ti * 4 + m >= outh) @@ -5012,7 +4851,7 @@ static inline void conv3x3s1_winograd43_transform_output_tile(const Mat& top_til } } -static void conv3x3s1_winograd43(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, const Option& opt) +static void conv3x3s1_winograd43(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int nT, const Option& opt) { int outw = top_blob.w; int outh = top_blob.h; @@ -5029,58 +4868,76 @@ static void conv3x3s1_winograd43(const Mat& bottom_blob, Mat& top_blob, const Ma // NCNN_LOGE("conv3x3s1_winograd43 %d %d %d", M, N, K); - int nT = opt.num_threads; - int TILE_M, TILE_N, TILE_K; get_optimal_tile_mnk(M, N, K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); - Mat B_tileX(B * TILE_N * TILE_K, 1, nT, 4u, opt.blob_allocator); Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); - #pragma omp parallel for num_threads(nT) - for (int ppj = 0; ppj < nn_N; ppj++) + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) { - const int j = ppj * TILE_N; + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.blob_allocator); - Mat B_tile = B_tileX.channel(get_omp_thread_num()); + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; - const int max_jj = std::min((N - j), TILE_N); + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { + const int max_jj = std::min((N - j), TILE_N); const int max_kk = std::min((K - k), TILE_K); // transform input - conv3x3s1_winograd43_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk); + conv3x3s1_winograd43_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk); + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, nT); } } - - Mat tmpX; - if (TILE_K < K) + else { - tmpX.create(TILE_M * TILE_N, B, nT, 4u, opt.blob_allocator); + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.blob_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, 1); + } } - Mat top_tileX(TILE_M, TILE_N, B, nT, 4u, opt.blob_allocator); + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.blob_allocator); #pragma omp parallel for num_threads(nT) for (int ppj = 0; ppj < nn_M; ppj++) { const int i = ppj * TILE_M; - Mat tmp; - if (K > TILE_K) - tmp = tmpX.channel(get_omp_thread_num()); - Mat top_tile = top_tileX.channel(get_omp_thread_num()); const int max_ii = std::min((M - i), TILE_M); @@ -5097,9 +4954,7 @@ static void conv3x3s1_winograd43(const Mat& bottom_blob, Mat& top_blob, const Ma const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - bool k_end = k + TILE_K >= K; - - gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, tmp, B, max_ii, max_jj, k, max_kk, k_end); + gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); } // transform output @@ -5221,7 +5076,7 @@ static void conv3x3s1_winograd63_transform_kernel(const Mat& kernel, Mat& AT, in } } -static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) { // const float itm[8][8] = { // {1.0f, 0.0f,-5.25f, 0.00f, 5.25f, 0.00f,-1.0f, 0.0f}, @@ -5241,27 +5096,30 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b const int w_tiles = (w + 3) / 6; - float* ptmp = B; - - int kk = 0; + int nn_max_kk = 0; + int remain_max_kk_start = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - for (; kk + 15 < max_kk; kk += 16) + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[8][8][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[8][8][16]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 6) + (tj * 6) * elempack; for (int m = 0; m < 8; m++) @@ -5428,6 +5286,16 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 64 + jj * 16; + float* p1 = p0 + max_jj * 16; + float* p2 = p0 + max_jj * 16 * 2; + float* p3 = p0 + max_jj * 16 * 3; + float* p4 = p0 + max_jj * 16 * 4; + float* p5 = p0 + max_jj * 16 * 5; + float* p6 = p0 + max_jj * 16 * 6; + float* p7 = p0 + max_jj * 16 * 7; + for (int m = 0; m < 8; m++) { __m512 _r0 = _mm512_load_ps(tmp[m][0]); @@ -5464,34 +5332,49 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b __m512 _tmp6 = _mm512_sub_ps(_tmp56a, _tmp56b); __m512 _tmp7 = _mm512_fmadd_ps(_v5_25, _mm512_sub_ps(_r3, _r5), _mm512_sub_ps(_r7, _r1)); - _mm512_store_ps(ptmp, _tmp0); - _mm512_store_ps(ptmp + 16, _tmp1); - _mm512_store_ps(ptmp + 32, _tmp2); - _mm512_store_ps(ptmp + 48, _tmp3); - _mm512_store_ps(ptmp + 64, _tmp4); - _mm512_store_ps(ptmp + 80, _tmp5); - _mm512_store_ps(ptmp + 96, _tmp6); - _mm512_store_ps(ptmp + 112, _tmp7); - ptmp += 128; + _mm512_store_ps(p0, _tmp0); + _mm512_store_ps(p1, _tmp1); + _mm512_store_ps(p2, _tmp2); + _mm512_store_ps(p3, _tmp3); + _mm512_store_ps(p4, _tmp4); + _mm512_store_ps(p5, _tmp5); + _mm512_store_ps(p6, _tmp6); + _mm512_store_ps(p7, _tmp7); + + p0 += max_jj * 8 * 16; + p1 += max_jj * 8 * 16; + p2 += max_jj * 8 * 16; + p3 += max_jj * 8 * 16; + p4 += max_jj * 8 * 16; + p5 += max_jj * 8 * 16; + p6 += max_jj * 8 * 16; + p7 += max_jj * 8 * 16; } } } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) #endif // __AVX512F__ - for (; kk + 7 < max_kk; kk += 8) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[8][8][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[8][8][8]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 6) + (tj * 6) * elempack; for (int m = 0; m < 8; m++) @@ -5626,6 +5509,16 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 64 + jj * 8; + float* p1 = p0 + max_jj * 8; + float* p2 = p0 + max_jj * 8 * 2; + float* p3 = p0 + max_jj * 8 * 3; + float* p4 = p0 + max_jj * 8 * 4; + float* p5 = p0 + max_jj * 8 * 5; + float* p6 = p0 + max_jj * 8 * 6; + float* p7 = p0 + max_jj * 8 * 7; + for (int m = 0; m < 8; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -5673,34 +5566,49 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b __m256 _tmp6 = _mm256_sub_ps(_tmp56a, _tmp56b); __m256 _tmp7 = _mm256_comp_fmadd_ps(_v5_25, _mm256_sub_ps(_r3, _r5), _mm256_sub_ps(_r7, _r1)); - _mm256_store_ps(ptmp, _tmp0); - _mm256_store_ps(ptmp + 8, _tmp1); - _mm256_store_ps(ptmp + 16, _tmp2); - _mm256_store_ps(ptmp + 24, _tmp3); - _mm256_store_ps(ptmp + 32, _tmp4); - _mm256_store_ps(ptmp + 40, _tmp5); - _mm256_store_ps(ptmp + 48, _tmp6); - _mm256_store_ps(ptmp + 56, _tmp7); - ptmp += 64; + _mm256_store_ps(p0, _tmp0); + _mm256_store_ps(p1, _tmp1); + _mm256_store_ps(p2, _tmp2); + _mm256_store_ps(p3, _tmp3); + _mm256_store_ps(p4, _tmp4); + _mm256_store_ps(p5, _tmp5); + _mm256_store_ps(p6, _tmp6); + _mm256_store_ps(p7, _tmp7); + + p0 += max_jj * 8 * 8; + p1 += max_jj * 8 * 8; + p2 += max_jj * 8 * 8; + p3 += max_jj * 8 * 8; + p4 += max_jj * 8 * 8; + p5 += max_jj * 8 * 8; + p6 += max_jj * 8 * 8; + p7 += max_jj * 8 * 8; } } } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 4; +#else // __AVX__ + nn_max_kk = (max_kk - remain_max_kk_start) / 4; + #pragma omp parallel for num_threads(nT) #endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 4; + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[8][8][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[8][8][4]; - const float* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 6) + (tj * 6) * elempack; for (int m = 0; m < 8; m++) @@ -5808,6 +5716,16 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b r0 += w * elempack; } + + float* p0 = (float*)B + kk * max_jj * 64 + jj * 4; + float* p1 = p0 + max_jj * 4; + float* p2 = p0 + max_jj * 4 * 2; + float* p3 = p0 + max_jj * 4 * 3; + float* p4 = p0 + max_jj * 4 * 4; + float* p5 = p0 + max_jj * 4 * 5; + float* p6 = p0 + max_jj * 4 * 6; + float* p7 = p0 + max_jj * 4 * 7; + for (int m = 0; m < 8; m++) { #if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) @@ -5855,29 +5773,44 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b __m128 _tmp6 = _mm_sub_ps(_tmp56a, _tmp56b); __m128 _tmp7 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r3, _r5), _mm_sub_ps(_r7, _r1)); - _mm_store_ps(ptmp, _tmp0); - _mm_store_ps(ptmp + 4, _tmp1); - _mm_store_ps(ptmp + 8, _tmp2); - _mm_store_ps(ptmp + 12, _tmp3); - _mm_store_ps(ptmp + 16, _tmp4); - _mm_store_ps(ptmp + 20, _tmp5); - _mm_store_ps(ptmp + 24, _tmp6); - _mm_store_ps(ptmp + 28, _tmp7); - ptmp += 32; + _mm_store_ps(p0, _tmp0); + _mm_store_ps(p1, _tmp1); + _mm_store_ps(p2, _tmp2); + _mm_store_ps(p3, _tmp3); + _mm_store_ps(p4, _tmp4); + _mm_store_ps(p5, _tmp5); + _mm_store_ps(p6, _tmp6); + _mm_store_ps(p7, _tmp7); + + p0 += max_jj * 8 * 4; + p1 += max_jj * 8 * 4; + p2 += max_jj * 8 * 4; + p3 += max_jj * 8 * 4; + p4 += max_jj * 8 * 4; + p5 += max_jj * 8 * 4; + p6 += max_jj * 8 * 4; + p7 += max_jj * 8 * 4; } } } + remain_max_kk_start += nn_max_kk * 4; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) #endif // __SSE2__ - for (; kk + 1 < max_kk; kk += 2) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) { + const int kk = remain_max_kk_start + ppkk * 2; + + float tmp[8][8][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[8][8][2]; - const float* r0 = bottom_blob.channel(k + kk).row(ti * 6) + (tj * 6); for (int m = 0; m < 8; m++) @@ -5977,6 +5910,16 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b r0 += w; } + + float* p0 = (float*)B + kk * max_jj * 64 + jj * 2; + float* p1 = p0 + max_jj * 2; + float* p2 = p0 + max_jj * 2 * 2; + float* p3 = p0 + max_jj * 2 * 3; + float* p4 = p0 + max_jj * 2 * 4; + float* p5 = p0 + max_jj * 2 * 5; + float* p6 = p0 + max_jj * 2 * 6; + float* p7 = p0 + max_jj * 2 * 7; + for (int m = 0; m < 8; m++) { float r00 = tmp[m][0][0]; @@ -6009,36 +5952,45 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b float tmp56b0 = r10 * 2.f - r30 * 2.5f + r50 * 0.5f; float tmp56b1 = r11 * 2.f - r31 * 2.5f + r51 * 0.5f; - ptmp[0] = r00 - r60 + (r40 - r20) * 5.25f; - ptmp[1] = r01 - r61 + (r41 - r21) * 5.25f; - ptmp[2] = tmp12a0 + tmp12b0; - ptmp[3] = tmp12a1 + tmp12b1; - ptmp[4] = tmp12a0 - tmp12b0; - ptmp[5] = tmp12a1 - tmp12b1; - ptmp[6] = tmp34a0 + tmp34b0; - ptmp[7] = tmp34a1 + tmp34b1; - ptmp[8] = tmp34a0 - tmp34b0; - ptmp[9] = tmp34a1 - tmp34b1; - ptmp[10] = tmp56a0 + tmp56b0; - ptmp[11] = tmp56a1 + tmp56b1; - ptmp[12] = tmp56a0 - tmp56b0; - ptmp[13] = tmp56a1 - tmp56b1; - ptmp[14] = r70 - r10 + (r30 - r50) * 5.25f; - ptmp[15] = r71 - r11 + (r31 - r51) * 5.25f; - ptmp += 16; + p0[0] = r00 - r60 + (r40 - r20) * 5.25f; + p0[1] = r01 - r61 + (r41 - r21) * 5.25f; + p1[0] = tmp12a0 + tmp12b0; + p1[1] = tmp12a1 + tmp12b1; + p2[0] = tmp12a0 - tmp12b0; + p2[1] = tmp12a1 - tmp12b1; + p3[0] = tmp34a0 + tmp34b0; + p3[1] = tmp34a1 + tmp34b1; + p4[0] = tmp34a0 - tmp34b0; + p4[1] = tmp34a1 - tmp34b1; + p5[0] = tmp56a0 + tmp56b0; + p5[1] = tmp56a1 + tmp56b1; + p6[0] = tmp56a0 - tmp56b0; + p6[1] = tmp56a1 - tmp56b1; + p7[0] = r70 - r10 + (r30 - r50) * 5.25f; + p7[1] = r71 - r11 + (r31 - r51) * 5.25f; + + p0 += max_jj * 8 * 2; + p1 += max_jj * 8 * 2; + p2 += max_jj * 8 * 2; + p3 += max_jj * 8 * 2; + p4 += max_jj * 8 * 2; + p5 += max_jj * 8 * 2; + p6 += max_jj * 8 * 2; + p7 += max_jj * 8 * 2; } } } - for (; kk < max_kk; kk++) + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) { + float tmp[8][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[8][8]; - const float* r0123 = bottom_blob.channel(k + kk).row(ti * 6) + (tj * 6); for (int m = 0; m < 8; m++) @@ -6085,6 +6037,16 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b r0123 += w; } + + float* p0 = (float*)B + kk * max_jj * 64 + jj; + float* p1 = p0 + max_jj; + float* p2 = p0 + max_jj * 2; + float* p3 = p0 + max_jj * 3; + float* p4 = p0 + max_jj * 4; + float* p5 = p0 + max_jj * 5; + float* p6 = p0 + max_jj * 6; + float* p7 = p0 + max_jj * 7; + for (int m = 0; m < 8; m++) { float r0 = tmp[m][0]; @@ -6103,15 +6065,23 @@ static inline void conv3x3s1_winograd63_transform_input_tile(const Mat& bottom_b float tmp56a = r2 * 4.f - r4 * 5.f + r6; float tmp56b = r1 * 2.f - r3 * 2.5f + r5 * 0.5f; - ptmp[0] = r0 - r6 + (r4 - r2) * 5.25f; - ptmp[1] = tmp12a + tmp12b; - ptmp[2] = tmp12a - tmp12b; - ptmp[3] = tmp34a + tmp34b; - ptmp[4] = tmp34a - tmp34b; - ptmp[5] = tmp56a + tmp56b; - ptmp[6] = tmp56a - tmp56b; - ptmp[7] = r7 - r1 + (r3 - r5) * 5.25f; - ptmp += 8; + p0[0] = r0 - r6 + (r4 - r2) * 5.25f; + p1[0] = tmp12a + tmp12b; + p2[0] = tmp12a - tmp12b; + p3[0] = tmp34a + tmp34b; + p4[0] = tmp34a - tmp34b; + p5[0] = tmp56a + tmp56b; + p6[0] = tmp56a - tmp56b; + p7[0] = r7 - r1 + (r3 - r5) * 5.25f; + + p0 += max_jj * 8; + p1 += max_jj * 8; + p2 += max_jj * 8; + p3 += max_jj * 8; + p4 += max_jj * 8; + p5 += max_jj * 8; + p6 += max_jj * 8; + p7 += max_jj * 8; } } } @@ -6145,20 +6115,27 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til { __m512 _bias0 = biasptr ? _mm512_loadu_ps(biasptr + i + ii) : _mm512_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[6][8][16]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[6][8][16]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 64 + jj * 16; + const float* r1 = r0 + max_jj * 16; + const float* r2 = r0 + max_jj * 16 * 2; + const float* r3 = r0 + max_jj * 16 * 3; + const float* r4 = r0 + max_jj * 16 * 4; + const float* r5 = r0 + max_jj * 16 * 5; + const float* r6 = r0 + max_jj * 16 * 6; + const float* r7 = r0 + max_jj * 16 * 7; __m512 _v32 = _mm512_set1_ps(32.f); __m512 _v16 = _mm512_set1_ps(16.f); @@ -6168,15 +6145,6 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til for (int m = 0; m < 8; m++) { - const float* r0 = top_tile.depth(m * 8).row(jj) + ii; - const float* r1 = top_tile.depth(m * 8 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 8 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 8 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 8 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 8 + 5).row(jj) + ii; - const float* r6 = top_tile.depth(m * 8 + 6).row(jj) + ii; - const float* r7 = top_tile.depth(m * 8 + 7).row(jj) + ii; - __m512 _r0 = _mm512_load_ps(r0); __m512 _r1 = _mm512_load_ps(r1); __m512 _r2 = _mm512_load_ps(r2); @@ -6205,7 +6173,19 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til _mm512_store_ps(tmp[3][m], _tmp3); _mm512_store_ps(tmp[4][m], _tmp4); _mm512_store_ps(tmp[5][m], _tmp5); + + r0 += max_jj * 8 * 16; + r1 += max_jj * 8 * 16; + r2 += max_jj * 8 * 16; + r3 += max_jj * 8 * 16; + r4 += max_jj * 8 * 16; + r5 += max_jj * 8 * 16; + r6 += max_jj * 8 * 16; + r7 += max_jj * 8 * 16; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + for (int m = 0; m < 6; m++) { if (ti * 6 + m >= outh) @@ -6473,20 +6453,27 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til { __m256 _bias0 = biasptr ? _mm256_loadu_ps(biasptr + i + ii) : _mm256_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[6][8][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[6][8][8]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 64 + jj * 8; + const float* r1 = r0 + max_jj * 8; + const float* r2 = r0 + max_jj * 8 * 2; + const float* r3 = r0 + max_jj * 8 * 3; + const float* r4 = r0 + max_jj * 8 * 4; + const float* r5 = r0 + max_jj * 8 * 5; + const float* r6 = r0 + max_jj * 8 * 6; + const float* r7 = r0 + max_jj * 8 * 7; __m256 _v32 = _mm256_set1_ps(32.f); __m256 _v16 = _mm256_set1_ps(16.f); @@ -6496,15 +6483,6 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til for (int m = 0; m < 8; m++) { - const float* r0 = top_tile.depth(m * 8).row(jj) + ii; - const float* r1 = top_tile.depth(m * 8 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 8 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 8 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 8 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 8 + 5).row(jj) + ii; - const float* r6 = top_tile.depth(m * 8 + 6).row(jj) + ii; - const float* r7 = top_tile.depth(m * 8 + 7).row(jj) + ii; - __m256 _r0 = _mm256_load_ps(r0); __m256 _r1 = _mm256_load_ps(r1); __m256 _r2 = _mm256_load_ps(r2); @@ -6542,7 +6520,19 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til _mm256_store_ps(tmp[4][m], _tmp4); _mm256_store_ps(tmp[5][m], _tmp5); #endif + + r0 += max_jj * 8 * 8; + r1 += max_jj * 8 * 8; + r2 += max_jj * 8 * 8; + r3 += max_jj * 8 * 8; + r4 += max_jj * 8 * 8; + r5 += max_jj * 8 * 8; + r6 += max_jj * 8 * 8; + r7 += max_jj * 8 * 8; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + for (int m = 0; m < 6; m++) { if (ti * 6 + m >= outh) @@ -6719,20 +6709,27 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til { __m128 _bias0 = biasptr ? _mm_loadu_ps(biasptr + i + ii) : _mm_setzero_ps(); +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[6][8][4]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[6][8][4]; - - float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + const float* r0 = (const float*)top_tile + ii * max_jj * 64 + jj * 4; + const float* r1 = r0 + max_jj * 4; + const float* r2 = r0 + max_jj * 4 * 2; + const float* r3 = r0 + max_jj * 4 * 3; + const float* r4 = r0 + max_jj * 4 * 4; + const float* r5 = r0 + max_jj * 4 * 5; + const float* r6 = r0 + max_jj * 4 * 6; + const float* r7 = r0 + max_jj * 4 * 7; __m128 _v32 = _mm_set1_ps(32.f); __m128 _v16 = _mm_set1_ps(16.f); @@ -6742,15 +6739,6 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til for (int m = 0; m < 8; m++) { - const float* r0 = top_tile.depth(m * 8).row(jj) + ii; - const float* r1 = top_tile.depth(m * 8 + 1).row(jj) + ii; - const float* r2 = top_tile.depth(m * 8 + 2).row(jj) + ii; - const float* r3 = top_tile.depth(m * 8 + 3).row(jj) + ii; - const float* r4 = top_tile.depth(m * 8 + 4).row(jj) + ii; - const float* r5 = top_tile.depth(m * 8 + 5).row(jj) + ii; - const float* r6 = top_tile.depth(m * 8 + 6).row(jj) + ii; - const float* r7 = top_tile.depth(m * 8 + 7).row(jj) + ii; - __m128 _r0 = _mm_load_ps(r0); __m128 _r1 = _mm_load_ps(r1); __m128 _r2 = _mm_load_ps(r2); @@ -6788,7 +6776,19 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til _mm_store_ps(tmp[4][m], _tmp4); _mm_store_ps(tmp[5][m], _tmp5); #endif + + r0 += max_jj * 8 * 4; + r1 += max_jj * 8 * 4; + r2 += max_jj * 8 * 4; + r3 += max_jj * 8 * 4; + r4 += max_jj * 8 * 4; + r5 += max_jj * 8 * 4; + r6 += max_jj * 8 * 4; + r7 += max_jj * 8 * 4; } + + float* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 6) + (tj * 6) * out_elempack; + for (int m = 0; m < 6; m++) { if (ti * 6 + m >= outh) @@ -6906,50 +6906,40 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til float bias0 = biasptr ? biasptr[i + ii] : 0.f; float bias1 = biasptr ? biasptr[i + ii + 1] : 0.f; + float tmp[6][8][2]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[6][8][2]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 6) + (tj * 6); + const float* r0 = (const float*)top_tile + ii * max_jj * 64 + jj * 2; + const float* r1 = r0 + max_jj * 2; + const float* r2 = r0 + max_jj * 2 * 2; + const float* r3 = r0 + max_jj * 2 * 3; + const float* r4 = r0 + max_jj * 2 * 4; + const float* r5 = r0 + max_jj * 2 * 5; + const float* r6 = r0 + max_jj * 2 * 6; + const float* r7 = r0 + max_jj * 2 * 7; for (int m = 0; m < 8; m++) { - float r00 = top_tile.depth(m * 8).row(jj)[ii]; - float r01 = top_tile.depth(m * 8).row(jj)[ii + 1]; - float r10 = top_tile.depth(m * 8 + 1).row(jj)[ii]; - float r11 = top_tile.depth(m * 8 + 1).row(jj)[ii + 1]; - float r20 = top_tile.depth(m * 8 + 2).row(jj)[ii]; - float r21 = top_tile.depth(m * 8 + 2).row(jj)[ii + 1]; - float r30 = top_tile.depth(m * 8 + 3).row(jj)[ii]; - float r31 = top_tile.depth(m * 8 + 3).row(jj)[ii + 1]; - float r40 = top_tile.depth(m * 8 + 4).row(jj)[ii]; - float r41 = top_tile.depth(m * 8 + 4).row(jj)[ii + 1]; - float r50 = top_tile.depth(m * 8 + 5).row(jj)[ii]; - float r51 = top_tile.depth(m * 8 + 5).row(jj)[ii + 1]; - float r60 = top_tile.depth(m * 8 + 6).row(jj)[ii]; - float r61 = top_tile.depth(m * 8 + 6).row(jj)[ii + 1]; - float r70 = top_tile.depth(m * 8 + 7).row(jj)[ii]; - float r71 = top_tile.depth(m * 8 + 7).row(jj)[ii + 1]; - - float tmp024a0 = r10 + r20; - float tmp024a1 = r11 + r21; - float tmp135a0 = r10 - r20; - float tmp135a1 = r11 - r21; - float tmp024b0 = r30 + r40; - float tmp024b1 = r31 + r41; - float tmp135b0 = r30 - r40; - float tmp135b1 = r31 - r41; - float tmp024c0 = r50 + r60; - float tmp024c1 = r51 + r61; - float tmp135c0 = r50 - r60; - float tmp135c1 = r51 - r61; - - tmp[0][m][0] = r00 + tmp024a0 + tmp024b0 + tmp024c0 * 32; - tmp[0][m][1] = r01 + tmp024a1 + tmp024b1 + tmp024c1 * 32; + float tmp024a0 = r1[0] + r2[0]; + float tmp024a1 = r1[1] + r2[1]; + float tmp135a0 = r1[0] - r2[0]; + float tmp135a1 = r1[1] - r2[1]; + float tmp024b0 = r3[0] + r4[0]; + float tmp024b1 = r3[1] + r4[1]; + float tmp135b0 = r3[0] - r4[0]; + float tmp135b1 = r3[1] - r4[1]; + float tmp024c0 = r5[0] + r6[0]; + float tmp024c1 = r5[1] + r6[1]; + float tmp135c0 = r5[0] - r6[0]; + float tmp135c1 = r5[1] - r6[1]; + + tmp[0][m][0] = r0[0] + tmp024a0 + tmp024b0 + tmp024c0 * 32; + tmp[0][m][1] = r0[1] + tmp024a1 + tmp024b1 + tmp024c1 * 32; tmp[1][m][0] = tmp135a0 + tmp135b0 + tmp135b0 + tmp135c0 * 16; tmp[1][m][1] = tmp135a1 + tmp135b1 + tmp135b1 + tmp135c1 * 16; tmp[2][m][0] = tmp024a0 + tmp024b0 * 4 + tmp024c0 * 8; @@ -6958,9 +6948,21 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til tmp[3][m][1] = tmp135a1 + tmp135b1 * 8 + tmp135c1 * 4; tmp[4][m][0] = tmp024a0 + tmp024b0 * 16 + tmp024c0 + tmp024c0; tmp[4][m][1] = tmp024a1 + tmp024b1 * 16 + tmp024c1 + tmp024c1; - tmp[5][m][0] = r70 + tmp135a0 + tmp135b0 * 32 + tmp135c0; - tmp[5][m][1] = r71 + tmp135a1 + tmp135b1 * 32 + tmp135c1; + tmp[5][m][0] = r7[0] + tmp135a0 + tmp135b0 * 32 + tmp135c0; + tmp[5][m][1] = r7[1] + tmp135a1 + tmp135b1 * 32 + tmp135c1; + + r0 += max_jj * 8 * 2; + r1 += max_jj * 8 * 2; + r2 += max_jj * 8 * 2; + r3 += max_jj * 8 * 2; + r4 += max_jj * 8 * 2; + r5 += max_jj * 8 * 2; + r6 += max_jj * 8 * 2; + r7 += max_jj * 8 * 2; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 6) + (tj * 6); + for (int m = 0; m < 6; m++) { if (ti * 6 + m >= outh) @@ -7050,41 +7052,51 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til { float bias0 = biasptr ? biasptr[i + ii] : 0.f; + float tmp[6][8]; + int jj = 0; for (; jj < max_jj; jj++) { int ti = (j + jj) / w_tiles; int tj = (j + jj) % w_tiles; - float tmp[6][8]; - - float* outptr0 = top_blob.channel(i + ii).row(ti * 6) + (tj * 6); + const float* r0 = (const float*)top_tile + ii * max_jj * 64 + jj; + const float* r1 = r0 + max_jj; + const float* r2 = r0 + max_jj * 2; + const float* r3 = r0 + max_jj * 3; + const float* r4 = r0 + max_jj * 4; + const float* r5 = r0 + max_jj * 5; + const float* r6 = r0 + max_jj * 6; + const float* r7 = r0 + max_jj * 7; for (int m = 0; m < 8; m++) { - float r0 = top_tile.depth(m * 8).row(jj)[ii]; - float r1 = top_tile.depth(m * 8 + 1).row(jj)[ii]; - float r2 = top_tile.depth(m * 8 + 2).row(jj)[ii]; - float r3 = top_tile.depth(m * 8 + 3).row(jj)[ii]; - float r4 = top_tile.depth(m * 8 + 4).row(jj)[ii]; - float r5 = top_tile.depth(m * 8 + 5).row(jj)[ii]; - float r6 = top_tile.depth(m * 8 + 6).row(jj)[ii]; - float r7 = top_tile.depth(m * 8 + 7).row(jj)[ii]; - - float tmp024a = r1 + r2; - float tmp135a = r1 - r2; - float tmp024b = r3 + r4; - float tmp135b = r3 - r4; - float tmp024c = r5 + r6; - float tmp135c = r5 - r6; - - tmp[0][m] = r0 + tmp024a + tmp024b + tmp024c * 32; + float tmp024a = r1[0] + r2[0]; + float tmp135a = r1[0] - r2[0]; + float tmp024b = r3[0] + r4[0]; + float tmp135b = r3[0] - r4[0]; + float tmp024c = r5[0] + r6[0]; + float tmp135c = r5[0] - r6[0]; + + tmp[0][m] = r0[0] + tmp024a + tmp024b + tmp024c * 32; tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - tmp[5][m] = r7 + tmp135a + tmp135b * 32 + tmp135c; + tmp[5][m] = r7[0] + tmp135a + tmp135b * 32 + tmp135c; + + r0 += max_jj * 8; + r1 += max_jj * 8; + r2 += max_jj * 8; + r3 += max_jj * 8; + r4 += max_jj * 8; + r5 += max_jj * 8; + r6 += max_jj * 8; + r7 += max_jj * 8; } + + float* outptr0 = top_blob.channel(i + ii).row(ti * 6) + (tj * 6); + for (int m = 0; m < 6; m++) { if (ti * 6 + m >= outh) @@ -7129,7 +7141,7 @@ static inline void conv3x3s1_winograd63_transform_output_tile(const Mat& top_til } } -static void conv3x3s1_winograd63(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, const Option& opt) +static void conv3x3s1_winograd63(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int nT, const Option& opt) { int outw = top_blob.w; int outh = top_blob.h; @@ -7146,58 +7158,76 @@ static void conv3x3s1_winograd63(const Mat& bottom_blob, Mat& top_blob, const Ma // NCNN_LOGE("conv3x3s1_winograd63 %d %d %d", M, N, K); - int nT = opt.num_threads; - int TILE_M, TILE_N, TILE_K; get_optimal_tile_mnk(M, N, K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); - Mat B_tileX(B * TILE_N * TILE_K, 1, nT, 4u, opt.blob_allocator); Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); - #pragma omp parallel for num_threads(nT) - for (int ppj = 0; ppj < nn_N; ppj++) + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) { - const int j = ppj * TILE_N; + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.blob_allocator); - Mat B_tile = B_tileX.channel(get_omp_thread_num()); + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; - const int max_jj = std::min((N - j), TILE_N); + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { + const int max_jj = std::min((N - j), TILE_N); const int max_kk = std::min((K - k), TILE_K); // transform input - conv3x3s1_winograd63_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk); + conv3x3s1_winograd63_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk); + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, nT); } } - - Mat tmpX; - if (TILE_K < K) + else { - tmpX.create(TILE_M * TILE_N, B, nT, 4u, opt.blob_allocator); + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.blob_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd63_transform_input_tile(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile(B_tile, BT_tile, B, max_jj, max_kk, 1); + } } - Mat top_tileX(TILE_M, TILE_N, B, nT, 4u, opt.blob_allocator); + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.blob_allocator); #pragma omp parallel for num_threads(nT) for (int ppj = 0; ppj < nn_M; ppj++) { const int i = ppj * TILE_M; - Mat tmp; - if (K > TILE_K) - tmp = tmpX.channel(get_omp_thread_num()); - Mat top_tile = top_tileX.channel(get_omp_thread_num()); const int max_ii = std::min((M - i), TILE_M); @@ -7214,9 +7244,7 @@ static void conv3x3s1_winograd63(const Mat& bottom_blob, Mat& top_blob, const Ma const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); - bool k_end = k + TILE_K >= K; - - gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, tmp, B, max_ii, max_jj, k, max_kk, k_end); + gemm_transB_packed_tile(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); } // transform output diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index 91d3c748d..cb5fba365 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -103,6 +103,7 @@ Convolution_x86::Convolution_x86() #endif // __SSE2__ activation = 0; + nT = 0; convolution_dilation1 = 0; gemm = 0; } @@ -143,12 +144,177 @@ static void convolution_transform_kernel_packed_sse(const Mat& weight_data, Mat& } } +static bool test_prefer_winograd63(int num_input, int num_output, int w, int h) +{ + // winograd selection strategy (profiled on i7-7700 single thread) + + int minwh = std::min(w, h); + + if (num_input >= 64) + { + return false; + } + if (num_input >= 32) + { + if (num_output >= 64) return false; + if (num_output >= 32) return (minwh >= 11 && minwh <= 14) + || (minwh >= 19 && minwh <= 20) + || (minwh >= 23 && minwh <= 44) + || (minwh >= 47 && minwh <= 56) + || (minwh >= 63 && minwh <= 130); + if (num_output >= 16) return (minwh >= 13 && minwh <= 14) + || (minwh >= 19 && minwh <= 20) + || (minwh >= 23 && minwh <= 38) + || (minwh >= 43 && minwh <= 44) + || (minwh >= 47 && minwh <= 140); + if (num_output >= 8) return (minwh >= 11 && minwh <= 14) + || (minwh >= 19 && minwh <= 20) + || (minwh >= 31 && minwh <= 38) + || (minwh >= 43 && minwh <= 44) + || (minwh >= 55 && minwh <= 162); + return false; + } + if (num_input >= 16) + { + if (num_output >= 64) return false; + if (num_output >= 32) return (minwh >= 11 && minwh <= 14) + || (minwh >= 19 && minwh <= 20) + || (minwh >= 23 && minwh <= 44) + || (minwh >= 47 && minwh <= 92) + || (minwh >= 95 && minwh <= 188); + if (num_output >= 16) return (minwh >= 11 && minwh <= 14) + || (minwh >= 27 && minwh <= 38) + || (minwh >= 43 && minwh <= 44) + || (minwh >= 47 && minwh <= 74) + || (minwh >= 81 && minwh <= 110) + || (minwh >= 117 && minwh <= 170) + || (minwh >= 177 && minwh <= 182); + if (num_output >= 8) return (minwh >= 19 && minwh <= 20) + || (minwh >= 33 && minwh <= 38) + || (minwh >= 43 && minwh <= 44) + || (minwh >= 47 && minwh <= 128) + || (minwh >= 155 && minwh <= 210); + return false; + } + if (num_input >= 8) + { + if (num_output >= 64) return false; + if (num_output >= 32) return (minwh >= 7 && minwh <= 14) + || (minwh >= 17 && minwh <= 20) + || (minwh >= 23 && minwh <= 26) + || (minwh >= 31 && minwh <= 38) + || (minwh >= 43 && minwh <= 162); + if (num_output >= 16) return minwh == 31 || minwh == 32 + || (minwh >= 39 && minwh <= 44) + || (minwh >= 47 && minwh <= 212); + if (num_output >= 8) return false; + return false; + } + + return false; +} + +static bool test_prefer_winograd23(int num_input, int num_output, int w, int h) +{ + int minwh = std::min(w, h); + + if (num_input >= 512) + { + if (num_output >= 512) return (minwh >= 3 && minwh <= 14); + if (num_output >= 256) return (minwh >= 3 && minwh <= 14); + if (num_output >= 128) return (minwh >= 3 && minwh <= 14); + if (num_output >= 64) return (minwh >= 3 && minwh <= 8) || (minwh >= 11 && minwh <= 12); + if (num_output >= 32) return (minwh >= 3 && minwh <= 8); + if (num_output >= 16) return (minwh >= 3 && minwh <= 8); + if (num_output >= 8) return (minwh >= 3 && minwh <= 6); + return false; + } + if (num_input >= 256) + { + if (num_output >= 512) return (minwh >= 3 && minwh <= 14); + if (num_output >= 256) return (minwh >= 3 && minwh <= 14); + if (num_output >= 128) return (minwh >= 3 && minwh <= 12); + if (num_output >= 64) return (minwh >= 3 && minwh <= 4); + if (num_output >= 32) return (minwh >= 3 && minwh <= 8); + if (num_output >= 16) return (minwh >= 3 && minwh <= 8); + if (num_output >= 8) return (minwh >= 3 && minwh <= 6); + return false; + } + if (num_input >= 128) + { + if (num_output >= 512) return (minwh >= 3 && minwh <= 14); + if (num_output >= 256) return (minwh >= 3 && minwh <= 8) || (minwh >= 11 && minwh <= 12); + if (num_output >= 128) return (minwh >= 3 && minwh <= 10); + if (num_output >= 64) return (minwh >= 3 && minwh <= 8); + if (num_output >= 32) return (minwh >= 3 && minwh <= 10); + if (num_output >= 16) return (minwh >= 3 && minwh <= 6); + if (num_output >= 8) return (minwh >= 3 && minwh <= 6); + return false; + } + if (num_input >= 64) + { + if (num_output >= 512) return (minwh >= 3 && minwh <= 8) || (minwh >= 11 && minwh <= 12) || (minwh >= 15 && minwh <= 20); + if (num_output >= 256) return (minwh >= 7 && minwh <= 8); + if (num_output >= 128) return (minwh >= 3 && minwh <= 8) || (minwh >= 19 && minwh <= 22); + if (num_output >= 64) return (minwh >= 3 && minwh <= 12); + if (num_output >= 32) return (minwh >= 3 && minwh <= 12); + if (num_output >= 16) return (minwh >= 3 && minwh <= 12); + if (num_output >= 8) return (minwh >= 3 && minwh <= 12); + return false; + } + if (num_input >= 32) + { + if (num_output >= 512) return (minwh >= 3 && minwh <= 6) || (minwh >= 11 && minwh <= 12); + if (num_output >= 256) return (minwh >= 3 && minwh <= 6) || (minwh >= 11 && minwh <= 12); + if (num_output >= 128) return (minwh >= 3 && minwh <= 4) || (minwh >= 7 && minwh <= 16); + if (num_output >= 64) return (minwh >= 3 && minwh <= 8); + if (num_output >= 32) return (minwh >= 7 && minwh <= 8); + if (num_output >= 16) return (minwh >= 7 && minwh <= 8); + if (num_output >= 8) return (minwh >= 3 && minwh <= 10); + return false; + } + if (num_input >= 16) + { + if (num_output >= 512) return (minwh >= 11 && minwh <= 12); + if (num_output >= 256) return (minwh >= 3 && minwh <= 12); + if (num_output >= 128) return (minwh >= 3 && minwh <= 6) + || (minwh >= 9 && minwh <= 18); + if (num_output >= 64) return (minwh >= 3 && minwh <= 4) + || (minwh >= 7 && minwh <= 8) + || (minwh >= 11 && minwh <= 12) + || (minwh >= 15 && minwh <= 18); + if (num_output >= 32) return (minwh >= 3 && minwh <= 4) + || (minwh >= 9 && minwh <= 10); + if (num_output >= 16) return (minwh >= 3 && minwh <= 10); + if (num_output >= 8) return (minwh >= 3 && minwh <= 8) + || (minwh >= 11 && minwh <= 12); + return false; + } + if (num_input >= 8) + { + if (num_output >= 128) return false; + if (num_output >= 64) return (minwh >= 3 && minwh <= 4) + || (minwh >= 7 && minwh <= 14) + || (minwh >= 47 && minwh <= 48); + if (num_output >= 32) return (minwh >= 3 && minwh <= 6) + || (minwh >= 15 && minwh <= 16); + if (num_output >= 16) return (minwh >= 3 && minwh <= 6) + || (minwh >= 9 && minwh <= 14) + || (minwh >= 47 && minwh <= 212); + if (num_output >= 8) return true; + return false; + } + + return false; +} + int Convolution_x86::create_pipeline(const Option& opt) { if (dynamic_weight) return 0; activation = create_activation_layer(activation_type, activation_params, opt); + nT = opt.num_threads; #if NCNN_INT8 if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u) @@ -226,14 +392,14 @@ int Convolution_x86::create_pipeline(const Option& opt) } #endif // __SSE2__ - bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution || opt.use_winograd63_convolution) && num_input >= 16 && num_output >= 16; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution || opt.use_winograd63_convolution) && (num_input > 8 || num_output > 8); if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { if ((bottom_shapes.empty() || bottom_shapes[0].w == 0 || bottom_shapes[0].h == 0) && (top_shapes.empty() || top_shapes[0].w == 0 || top_shapes[0].h == 0)) { // dynamic shape - if (opt.use_winograd63_convolution && num_input <= 24 && num_output <= 24) + if ((opt.use_winograd63_convolution) && (num_input <= 32 && num_output <= 32)) conv3x3s1_winograd63_transform_kernel(weight_data, weight_winograd63_data, num_input, num_output, opt); else if (opt.use_winograd43_convolution) conv3x3s1_winograd43_transform_kernel(weight_data, weight_winograd43_data, num_input, num_output, opt); @@ -242,20 +408,6 @@ int Convolution_x86::create_pipeline(const Option& opt) } else { - // winograd selection strategy - // - // | | | | | c/outc - // | | | | | f63 ^ - // | | | | | +----------------+128 - // | | | | |f63| f43 - // |f23|f43|f63|f43| +---+ +64 - // | | | | | f63 | f43 - // | | | | | +---+ +32 - // | | | | | f63 | f43 - // +---+---+---+---+---+---+---+--------+16 - // 0 14 19 21 31 96 132 192 --> wh - // - int w; int h; if (top_shapes.empty() || top_shapes[0].w == 0 || top_shapes[0].h == 0) @@ -283,18 +435,9 @@ int Convolution_x86::create_pipeline(const Option& opt) h = top_shapes[0].h + 2; } - const int minwh = std::min(w, h); - - bool prefer_winograd63 = minwh == 19 || minwh == 20 - || (minwh > 30 && num_input >= 128) - || (minwh > 30 && num_input >= 64 && num_input < 128 && num_output >= 128) - || (minwh > 30 && num_input >= 64 && num_input < 128 && num_output < 128 && minwh < 96) - || (minwh > 30 && num_input >= 16 && num_input < 64 && num_output >= 64) - || (minwh > 30 && num_input >= 32 && num_input < 64 && num_output < 64 && minwh < 132) - || (minwh > 30 && num_input >= 16 && num_input < 32 && num_output < 64 && minwh < 192); - - bool prefer_winograd43 = (minwh > 14 && !prefer_winograd63); - bool prefer_winograd23 = (!prefer_winograd43 && !prefer_winograd63); + bool prefer_winograd63 = test_prefer_winograd63(num_input, num_output, w, h); + bool prefer_winograd23 = test_prefer_winograd23(num_input, num_output, w, h); + bool prefer_winograd43 = !prefer_winograd63 && !prefer_winograd23; if (prefer_winograd23 && !opt.use_winograd23_convolution) { @@ -557,36 +700,13 @@ int Convolution_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option const int num_input = channels * elempack; - bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution || opt.use_winograd63_convolution) && num_input >= 16 && num_output >= 16; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution || opt.use_winograd63_convolution) && (num_input > 8 || num_output > 8); if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - // winograd selection strategy - // - // | | | | | c/outc - // | | | | | f63 ^ - // | | | | | +----------------+128 - // | | | | |f63| f43 - // |f23|f43|f63|f43| +---+ +64 - // | | | | | f63 | f43 - // | | | | | +---+ +32 - // | | | | | f63 | f43 - // +---+---+---+---+---+---+---+--------+16 - // 0 14 19 21 31 96 132 192 --> wh - // - - const int minwh = std::min(w, h); - - bool prefer_winograd63 = minwh == 19 || minwh == 20 - || (minwh > 30 && num_input >= 128) - || (minwh > 30 && num_input >= 64 && num_input < 128 && num_output >= 128) - || (minwh > 30 && num_input >= 64 && num_input < 128 && num_output < 128 && minwh < 96) - || (minwh > 30 && num_input >= 16 && num_input < 64 && num_output >= 64) - || (minwh > 30 && num_input >= 32 && num_input < 64 && num_output < 64 && minwh < 132) - || (minwh > 30 && num_input >= 16 && num_input < 32 && num_output < 64 && minwh < 192); - - bool prefer_winograd43 = (minwh > 14 && !prefer_winograd63); - bool prefer_winograd23 = (!prefer_winograd43 && !prefer_winograd63); + bool prefer_winograd63 = test_prefer_winograd63(num_input, num_output, w, h); + bool prefer_winograd23 = test_prefer_winograd23(num_input, num_output, w, h); + bool prefer_winograd43 = !prefer_winograd63 && !prefer_winograd23; if (prefer_winograd23 && (!opt.use_winograd23_convolution || weight_winograd23_data.empty())) { @@ -616,17 +736,25 @@ int Convolution_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option } } + int _nT = nT ? nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, convolution winograd will use load-time value %d", opt.num_threads, nT); + } + if (prefer_winograd23) { - conv3x3s1_winograd23(bottom_blob_bordered, top_blob, weight_winograd23_data, bias_data, opt); + conv3x3s1_winograd23(bottom_blob_bordered, top_blob, weight_winograd23_data, bias_data, _nT, opt); } else if (prefer_winograd43) { - conv3x3s1_winograd43(bottom_blob_bordered, top_blob, weight_winograd43_data, bias_data, opt); + conv3x3s1_winograd43(bottom_blob_bordered, top_blob, weight_winograd43_data, bias_data, _nT, opt); } else if (prefer_winograd63) { - conv3x3s1_winograd63(bottom_blob_bordered, top_blob, weight_winograd63_data, bias_data, opt); + conv3x3s1_winograd63(bottom_blob_bordered, top_blob, weight_winograd63_data, bias_data, _nT, opt); } else { diff --git a/src/layer/x86/convolution_x86.h b/src/layer/x86/convolution_x86.h index af548d935..44889ef5a 100644 --- a/src/layer/x86/convolution_x86.h +++ b/src/layer/x86/convolution_x86.h @@ -41,6 +41,7 @@ protected: public: Layer* activation; + int nT; Mat weight_data_tm; Mat weight_sgemm_data; Mat weight_winograd23_data; diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index c595c0fab..430435478 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -5873,6 +5873,7 @@ static int gemm_x86(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int int nn_M = (M + TILE_M - 1) / TILE_M; int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.blob_allocator); Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); @@ -5881,27 +5882,30 @@ static int gemm_x86(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int if (K > TILE_K) tmpX.create(TILE_N, TILE_M, nT, 4u, opt.blob_allocator); + const int nn_NK = nn_N * nn_K; + // pack B #pragma omp parallel for num_threads(nT) - for (int ppj = 0; ppj < nn_N; ppj++) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { - const int max_jj = std::min((N - j), TILE_N); - const int max_kk = std::min((K - k), TILE_K); + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (transB) - { - pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); - } - else - { - transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); - } + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } } @@ -5965,6 +5969,7 @@ static int gemm_AT_x86(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int nn_M = (M + TILE_M - 1) / TILE_M; int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); @@ -5972,27 +5977,30 @@ static int gemm_AT_x86(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, if (K > TILE_K) tmpX.create(TILE_N, TILE_M, nT, 4u, opt.blob_allocator); + const int nn_NK = nn_N * nn_K; + // pack B #pragma omp parallel for num_threads(nT) - for (int ppj = 0; ppj < nn_N; ppj++) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { - const int max_jj = std::min((N - j), TILE_N); - const int max_kk = std::min((K - k), TILE_K); + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (transB) - { - pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); - } - else - { - transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); - } + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } } @@ -6203,31 +6211,35 @@ int Gemm_x86::create_pipeline(const Option& opt) get_optimal_tile_mnk(0, N, K, TILE_M, TILE_N, TILE_K, opt.num_threads); const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.blob_allocator); if (BT_data.empty()) return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int ppj = 0; ppj < nn_N; ppj++) + const int nn_NK = nn_N * nn_K; + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; - for (int k = 0; k < K; k += TILE_K) - { - const int max_jj = std::min((N - j), TILE_N); - const int max_kk = std::min((K - k), TILE_K); + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); - Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (transB) - { - pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); - } - else - { - transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); - } + if (transB) + { + pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); } }