diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index aff757ede..134ec8c07 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -707,7 +707,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int { if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -750,7 +750,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int { if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -787,7 +787,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int { if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -820,7 +820,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __ARM_NEON if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -853,7 +853,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __ARM_NEON if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -3789,11 +3789,11 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, #if __aarch64__ TILE_M = tile_size / 8 * 8; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; TILE_K = tile_size / 8 * 8; #elif __ARM_NEON TILE_M = tile_size / 4 * 4; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; TILE_K = tile_size / 4 * 4; #else TILE_M = tile_size / 2 * 2; @@ -3818,10 +3818,10 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, #if __aarch64__ TILE_M = tile_size / 8 * 8; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; #elif __ARM_NEON TILE_M = tile_size / 4 * 4; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; #else TILE_M = tile_size / 2 * 2; TILE_N = tile_size; @@ -3846,7 +3846,13 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; +#if __aarch64__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __ARM_NEON + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif } if (nT > 1) diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index 75692c997..8b47e300d 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -1505,7 +1505,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __AVX512F__ if (elempack == 16) { - const float* p0 = (const float*)B + k / 16 * 16 * B_hstep + (j + jj) * 16; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 16; int kk = 0; for (; kk + 15 < max_kk; kk += 16) @@ -1542,7 +1542,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX512F__ if (elempack == 8) { - const float* p0 = (const float*)B + k / 8 * 8 * B_hstep + (j + jj) * 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 8; int kk = 0; for (; kk + 7 < max_kk; kk += 8) @@ -1579,7 +1579,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX__ if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -1636,7 +1636,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __AVX512F__ if (elempack == 16) { - const float* p0 = (const float*)B + k / 16 * 16 * B_hstep + (j + jj) * 16; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 16; int kk = 0; for (; kk + 15 < max_kk; kk += 16) @@ -1665,7 +1665,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX512F__ if (elempack == 8) { - const float* p0 = (const float*)B + k / 8 * 8 * B_hstep + (j + jj) * 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 8; int kk = 0; for (; kk + 7 < max_kk; kk += 8) @@ -1694,7 +1694,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX__ if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -1741,7 +1741,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __AVX512F__ if (elempack == 16) { - const float* p0 = (const float*)B + k / 16 * 16 * B_hstep + (j + jj) * 16; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 16; int kk = 0; for (; kk + 15 < max_kk; kk += 16) @@ -1762,7 +1762,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX512F__ if (elempack == 8) { - const float* p0 = (const float*)B + k / 8 * 8 * B_hstep + (j + jj) * 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 8; int kk = 0; for (; kk + 7 < max_kk; kk += 8) @@ -1783,7 +1783,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX__ if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -1822,7 +1822,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __AVX512F__ if (elempack == 16) { - const float* p0 = (const float*)B + k / 16 * 16 * B_hstep + (j + jj) * 16; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 16; int kk = 0; for (; kk + 15 < max_kk; kk += 16) @@ -1839,7 +1839,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX512F__ if (elempack == 8) { - const float* p0 = (const float*)B + k / 8 * 8 * B_hstep + (j + jj) * 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 8; int kk = 0; for (; kk + 7 < max_kk; kk += 8) @@ -1856,7 +1856,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX__ if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -1893,7 +1893,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __AVX512F__ if (elempack == 16) { - const float* p0 = (const float*)B + k / 16 * 16 * B_hstep + (j + jj) * 16; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 16; int kk = 0; for (; kk + 15 < max_kk; kk += 16) @@ -1906,7 +1906,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX512F__ if (elempack == 8) { - const float* p0 = (const float*)B + k / 8 * 8 * B_hstep + (j + jj) * 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 8; int kk = 0; for (; kk + 7 < max_kk; kk += 8) @@ -1919,7 +1919,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #endif // __AVX__ if (elempack == 4) { - const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; int kk = 0; for (; kk + 3 < max_kk; kk += 4) @@ -2066,143 +2066,105 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - __m256 _sum0_0 = _mm256_loadu_ps(pC); - __m256 _sum1_0 = _mm256_loadu_ps(pC + 8); - __m256 _sum2_0 = _mm256_loadu_ps(pC + 8 * 2); - __m256 _sum3_0 = _mm256_loadu_ps(pC + 8 * 3); - __m256 _sum4_0 = _mm256_loadu_ps(pC + 8 * 4); - __m256 _sum5_0 = _mm256_loadu_ps(pC + 8 * 5); - __m256 _sum6_0 = _mm256_loadu_ps(pC + 8 * 6); - __m256 _sum7_0 = _mm256_loadu_ps(pC + 8 * 7); - __m256 _sum8_0 = _mm256_loadu_ps(pC + 8 * 8); - __m256 _sum9_0 = _mm256_loadu_ps(pC + 8 * 9); - __m256 _suma_0 = _mm256_loadu_ps(pC + 8 * 10); - __m256 _sumb_0 = _mm256_loadu_ps(pC + 8 * 11); - __m256 _sum0_1 = _mm256_loadu_ps(pC + N * 8); - __m256 _sum1_1 = _mm256_loadu_ps(pC + N * 8 + 8); - __m256 _sum2_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 2); - __m256 _sum3_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 3); - __m256 _sum4_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 4); - __m256 _sum5_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 5); - __m256 _sum6_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 6); - __m256 _sum7_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 7); - __m256 _sum8_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 8); - __m256 _sum9_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 9); - __m256 _suma_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 10); - __m256 _sumb_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 11); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum0_0), _sum0_1, 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum1_0), _sum1_1, 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum2_0), _sum2_1, 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum3_0), _sum3_1, 1); - _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum4_0), _sum4_1, 1); - _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum5_0), _sum5_1, 1); - _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum6_0), _sum6_1, 1); - _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum7_0), _sum7_1, 1); - _sum8 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum8_0), _sum8_1, 1); - _sum9 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum9_0), _sum9_1, 1); - _suma = _mm512_insertf32x8(_mm512_castps256_ps512(_suma_0), _suma_1, 1); - _sumb = _mm512_insertf32x8(_mm512_castps256_ps512(_sumb_0), _sumb_1, 1); + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + 64); + __m512 _tmp5 = _mm512_loadu_ps(pC + 80); + __m512 _tmp6 = _mm512_loadu_ps(pC + N * 8); + __m512 _tmp7 = _mm512_loadu_ps(pC + N * 8 + 16); + __m512 _tmp8 = _mm512_loadu_ps(pC + N * 8 + 32); + __m512 _tmp9 = _mm512_loadu_ps(pC + N * 8 + 48); + __m512 _tmpa = _mm512_loadu_ps(pC + N * 8 + 64); + __m512 _tmpb = _mm512_loadu_ps(pC + N * 8 + 80); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + _sum2 = _mm512_shuffle_f32x4(_tmp1, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp1, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + _sum4 = _mm512_shuffle_f32x4(_tmp2, _tmp8, _MM_SHUFFLE(1, 0, 1, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp2, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + _sum6 = _mm512_shuffle_f32x4(_tmp3, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp3, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + _sum8 = _mm512_shuffle_f32x4(_tmp4, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + _sum9 = _mm512_shuffle_f32x4(_tmp4, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + _suma = _mm512_shuffle_f32x4(_tmp5, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + _sumb = _mm512_shuffle_f32x4(_tmp5, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); pC += 96; } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 4 * 2); - __m128 _sum3_0 = _mm_loadu_ps(pC + 4 * 3); - __m128 _sum4_0 = _mm_loadu_ps(pC + 4 * 4); - __m128 _sum5_0 = _mm_loadu_ps(pC + 4 * 5); - __m128 _sum6_0 = _mm_loadu_ps(pC + 4 * 6); - __m128 _sum7_0 = _mm_loadu_ps(pC + 4 * 7); - __m128 _sum8_0 = _mm_loadu_ps(pC + 4 * 8); - __m128 _sum9_0 = _mm_loadu_ps(pC + 4 * 9); - __m128 _suma_0 = _mm_loadu_ps(pC + 4 * 10); - __m128 _sumb_0 = _mm_loadu_ps(pC + 4 * 11); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 4 * 2); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 4 * 3); - __m128 _sum4_1 = _mm_loadu_ps(pC + N * 4 + 4 * 4); - __m128 _sum5_1 = _mm_loadu_ps(pC + N * 4 + 4 * 5); - __m128 _sum6_1 = _mm_loadu_ps(pC + N * 4 + 4 * 6); - __m128 _sum7_1 = _mm_loadu_ps(pC + N * 4 + 4 * 7); - __m128 _sum8_1 = _mm_loadu_ps(pC + N * 4 + 4 * 8); - __m128 _sum9_1 = _mm_loadu_ps(pC + N * 4 + 4 * 9); - __m128 _suma_1 = _mm_loadu_ps(pC + N * 4 + 4 * 10); - __m128 _sumb_1 = _mm_loadu_ps(pC + N * 4 + 4 * 11); - __m128 _sum0_2 = _mm_loadu_ps(pC + N * 8); - __m128 _sum1_2 = _mm_loadu_ps(pC + N * 8 + 4); - __m128 _sum2_2 = _mm_loadu_ps(pC + N * 8 + 4 * 2); - __m128 _sum3_2 = _mm_loadu_ps(pC + N * 8 + 4 * 3); - __m128 _sum4_2 = _mm_loadu_ps(pC + N * 8 + 4 * 4); - __m128 _sum5_2 = _mm_loadu_ps(pC + N * 8 + 4 * 5); - __m128 _sum6_2 = _mm_loadu_ps(pC + N * 8 + 4 * 6); - __m128 _sum7_2 = _mm_loadu_ps(pC + N * 8 + 4 * 7); - __m128 _sum8_2 = _mm_loadu_ps(pC + N * 8 + 4 * 8); - __m128 _sum9_2 = _mm_loadu_ps(pC + N * 8 + 4 * 9); - __m128 _suma_2 = _mm_loadu_ps(pC + N * 8 + 4 * 10); - __m128 _sumb_2 = _mm_loadu_ps(pC + N * 8 + 4 * 11); - __m128 _sum0_3 = _mm_loadu_ps(pC + N * 12); - __m128 _sum1_3 = _mm_loadu_ps(pC + N * 12 + 4); - __m128 _sum2_3 = _mm_loadu_ps(pC + N * 12 + 4 * 2); - __m128 _sum3_3 = _mm_loadu_ps(pC + N * 12 + 4 * 3); - __m128 _sum4_3 = _mm_loadu_ps(pC + N * 12 + 4 * 4); - __m128 _sum5_3 = _mm_loadu_ps(pC + N * 12 + 4 * 5); - __m128 _sum6_3 = _mm_loadu_ps(pC + N * 12 + 4 * 6); - __m128 _sum7_3 = _mm_loadu_ps(pC + N * 12 + 4 * 7); - __m128 _sum8_3 = _mm_loadu_ps(pC + N * 12 + 4 * 8); - __m128 _sum9_3 = _mm_loadu_ps(pC + N * 12 + 4 * 9); - __m128 _suma_3 = _mm_loadu_ps(pC + N * 12 + 4 * 10); - __m128 _sumb_3 = _mm_loadu_ps(pC + N * 12 + 4 * 11); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_2), _sum0_3, 1), 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_2), _sum1_3, 1), 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_2), _sum2_3, 1), 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_2), _sum3_3, 1), 1); - _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_0), _sum4_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_2), _sum4_3, 1), 1); - _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_0), _sum5_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_2), _sum5_3, 1), 1); - _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_0), _sum6_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_2), _sum6_3, 1), 1); - _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_0), _sum7_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_2), _sum7_3, 1), 1); - _sum8 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_0), _sum8_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_2), _sum8_3, 1), 1); - _sum9 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_0), _sum9_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_2), _sum9_3, 1), 1); - _suma = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_suma_0), _suma_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_suma_2), _suma_3, 1), 1); - _sumb = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sumb_0), _sumb_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sumb_2), _sumb_3, 1), 1); + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _mm512_loadu_ps(pC + 16); + _sum2 = _mm512_loadu_ps(pC + 32); + _sum3 = _mm512_loadu_ps(pC + N * 4); + _sum4 = _mm512_loadu_ps(pC + N * 4 + 16); + _sum5 = _mm512_loadu_ps(pC + N * 4 + 32); + _sum6 = _mm512_loadu_ps(pC + N * 8); + _sum7 = _mm512_loadu_ps(pC + N * 8 + 16); + _sum8 = _mm512_loadu_ps(pC + N * 8 + 32); + _sum9 = _mm512_loadu_ps(pC + N * 12); + _suma = _mm512_loadu_ps(pC + N * 12 + 16); + _sumb = _mm512_loadu_ps(pC + N * 12 + 32); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum6, _sum9, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum6, _sum9, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum1, _sum4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum7, _suma, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum1, _sum4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum7, _suma, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_sum2, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_sum8, _sumb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpa = _mm512_shuffle_f32x4(_sum2, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_sum8, _sumb, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum8 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _sum9 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _suma = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _sumb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + pC += 48; } if (out_elempack == 1) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + N); - __m128 _sum2_0 = _mm_loadu_ps(pC + N * 2); - __m128 _sum3_0 = _mm_loadu_ps(pC + N * 3); - __m128 _sum4_0 = _mm_loadu_ps(pC + N * 4); - __m128 _sum5_0 = _mm_loadu_ps(pC + N * 5); - __m128 _sum6_0 = _mm_loadu_ps(pC + N * 6); - __m128 _sum7_0 = _mm_loadu_ps(pC + N * 7); - __m128 _sum8_0 = _mm_loadu_ps(pC + N * 8); - __m128 _sum9_0 = _mm_loadu_ps(pC + N * 9); - __m128 _suma_0 = _mm_loadu_ps(pC + N * 10); - __m128 _sumb_0 = _mm_loadu_ps(pC + N * 11); - __m128 _sumc_0 = _mm_loadu_ps(pC + N * 12); - __m128 _sumd_0 = _mm_loadu_ps(pC + N * 13); - __m128 _sume_0 = _mm_loadu_ps(pC + N * 14); - __m128 _sumf_0 = _mm_loadu_ps(pC + N * 15); + __m256 _r0 = _mm256_loadu_ps(pC); + __m256 _r1 = _mm256_loadu_ps(pC + N); + __m256 _r2 = _mm256_loadu_ps(pC + N * 2); + __m256 _r3 = _mm256_loadu_ps(pC + N * 3); + __m256 _r4 = _mm256_loadu_ps(pC + N * 4); + __m256 _r5 = _mm256_loadu_ps(pC + N * 5); + __m256 _r6 = _mm256_loadu_ps(pC + N * 6); + __m256 _r7 = _mm256_loadu_ps(pC + N * 7); + __m256 _r8 = _mm256_loadu_ps(pC + N * 8); + __m256 _r9 = _mm256_loadu_ps(pC + N * 9); + __m256 _ra = _mm256_loadu_ps(pC + N * 10); + __m256 _rb = _mm256_loadu_ps(pC + N * 11); + __m256 _rc = _mm256_loadu_ps(pC + N * 12); + __m256 _rd = _mm256_loadu_ps(pC + N * 13); + __m256 _re = _mm256_loadu_ps(pC + N * 14); + __m256 _rf = _mm256_loadu_ps(pC + N * 15); + + transpose8x16_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - __m128 _sum0_1 = _mm_loadu_ps(pC + 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 2 + 4); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 3 + 4); - __m128 _sum4_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum5_1 = _mm_loadu_ps(pC + N * 5 + 4); - __m128 _sum6_1 = _mm_loadu_ps(pC + N * 6 + 4); - __m128 _sum7_1 = _mm_loadu_ps(pC + N * 7 + 4); - __m128 _sum8_1 = _mm_loadu_ps(pC + N * 8 + 4); - __m128 _sum9_1 = _mm_loadu_ps(pC + N * 9 + 4); - __m128 _suma_1 = _mm_loadu_ps(pC + N * 10 + 4); - __m128 _sumb_1 = _mm_loadu_ps(pC + N * 11 + 4); - __m128 _sumc_1 = _mm_loadu_ps(pC + N * 12 + 4); - __m128 _sumd_1 = _mm_loadu_ps(pC + N * 13 + 4); - __m128 _sume_1 = _mm_loadu_ps(pC + N * 14 + 4); - __m128 _sumf_1 = _mm_loadu_ps(pC + N * 15 + 4); + _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_r0), _r1, 1); + _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_r2), _r3, 1); + _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_r4), _r5, 1); + _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_r6), _r7, 1); + _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_r8), _r9, 1); + _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_ra), _rb, 1); + _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_rc), _rd, 1); + _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_re), _rf, 1); __m128 _sum0_2 = _mm_loadu_ps(pC + 8); __m128 _sum1_2 = _mm_loadu_ps(pC + N + 8); @@ -2221,31 +2183,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons __m128 _sume_2 = _mm_loadu_ps(pC + N * 14 + 8); __m128 _sumf_2 = _mm_loadu_ps(pC + N * 15 + 8); - _MM_TRANSPOSE4_PS(_sum0_0, _sum1_0, _sum2_0, _sum3_0); - _MM_TRANSPOSE4_PS(_sum4_0, _sum5_0, _sum6_0, _sum7_0); - _MM_TRANSPOSE4_PS(_sum8_0, _sum9_0, _suma_0, _sumb_0); - _MM_TRANSPOSE4_PS(_sumc_0, _sumd_0, _sume_0, _sumf_0); - - _MM_TRANSPOSE4_PS(_sum0_1, _sum1_1, _sum2_1, _sum3_1); - _MM_TRANSPOSE4_PS(_sum4_1, _sum5_1, _sum6_1, _sum7_1); - _MM_TRANSPOSE4_PS(_sum8_1, _sum9_1, _suma_1, _sumb_1); - _MM_TRANSPOSE4_PS(_sumc_1, _sumd_1, _sume_1, _sumf_1); - _MM_TRANSPOSE4_PS(_sum0_2, _sum1_2, _sum2_2, _sum3_2); _MM_TRANSPOSE4_PS(_sum4_2, _sum5_2, _sum6_2, _sum7_2); _MM_TRANSPOSE4_PS(_sum8_2, _sum9_2, _suma_2, _sumb_2); _MM_TRANSPOSE4_PS(_sumc_2, _sumd_2, _sume_2, _sumf_2); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum4_0, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_0), _sumc_0, 1), 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum5_0, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_0), _sumd_0, 1), 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum6_0, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_suma_0), _sume_0, 1), 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum7_0, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sumb_0), _sumf_0, 1), 1); - - _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_1), _sum4_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_1), _sumc_1, 1), 1); - _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_1), _sum5_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_1), _sumd_1, 1), 1); - _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_1), _sum6_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_suma_1), _sume_1, 1), 1); - _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_1), _sum7_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sumb_1), _sumf_1, 1), 1); - _sum8 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_2), _sum4_2, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_2), _sumc_2, 1), 1); _sum9 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_2), _sum5_2, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_2), _sumd_2, 1), 1); _suma = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_2), _sum6_2, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_suma_2), _sume_2, 1), 1); @@ -2330,87 +2272,75 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - _mm256_store_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); - _mm256_store_ps(outptr0 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 0)); - _mm256_store_ps(outptr0 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 0)); - _mm256_store_ps(outptr0 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 0)); - _mm256_store_ps(outptr0 + 8 * 4, _mm512_extractf32x8_ps(_sum4, 0)); - _mm256_store_ps(outptr0 + 8 * 5, _mm512_extractf32x8_ps(_sum5, 0)); - _mm256_store_ps(outptr0 + 8 * 6, _mm512_extractf32x8_ps(_sum6, 0)); - _mm256_store_ps(outptr0 + 8 * 7, _mm512_extractf32x8_ps(_sum7, 0)); - _mm256_store_ps(outptr0 + 8 * 8, _mm512_extractf32x8_ps(_sum8, 0)); - _mm256_store_ps(outptr0 + 8 * 9, _mm512_extractf32x8_ps(_sum9, 0)); - _mm256_store_ps(outptr0 + 8 * 10, _mm512_extractf32x8_ps(_suma, 0)); - _mm256_store_ps(outptr0 + 8 * 11, _mm512_extractf32x8_ps(_sumb, 0)); - - _mm256_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum0, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 4, _mm512_extractf32x8_ps(_sum4, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 5, _mm512_extractf32x8_ps(_sum5, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 6, _mm512_extractf32x8_ps(_sum6, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 7, _mm512_extractf32x8_ps(_sum7, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 8, _mm512_extractf32x8_ps(_sum8, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 9, _mm512_extractf32x8_ps(_sum9, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 10, _mm512_extractf32x8_ps(_suma, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 11, _mm512_extractf32x8_ps(_sumb, 1)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpa = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + _mm512_storeu_ps(outptr0 + 16 * 2, _tmp2); + _mm512_storeu_ps(outptr0 + 16 * 3, _tmp3); + _mm512_storeu_ps(outptr0 + 16 * 4, _tmp4); + _mm512_storeu_ps(outptr0 + 16 * 5, _tmp5); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp7); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 2, _tmp8); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 3, _tmp9); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 4, _tmpa); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 5, _tmpb); outptr0 += 96; } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm512_extractf32x4_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 0)); - _mm_store_ps(outptr0 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 0)); - _mm_store_ps(outptr0 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 0)); - _mm_store_ps(outptr0 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 0)); - _mm_store_ps(outptr0 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 0)); - _mm_store_ps(outptr0 + 4 * 8, _mm512_extractf32x4_ps(_sum8, 0)); - _mm_store_ps(outptr0 + 4 * 9, _mm512_extractf32x4_ps(_sum9, 0)); - _mm_store_ps(outptr0 + 4 * 10, _mm512_extractf32x4_ps(_suma, 0)); - _mm_store_ps(outptr0 + 4 * 11, _mm512_extractf32x4_ps(_sumb, 0)); - - _mm_store_ps(outptr0 + out_hstep * 4, _mm512_extractf32x4_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 8, _mm512_extractf32x4_ps(_sum8, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 9, _mm512_extractf32x4_ps(_sum9, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 10, _mm512_extractf32x4_ps(_suma, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 11, _mm512_extractf32x4_ps(_sumb, 1)); - - _mm_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x4_ps(_sum0, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 8, _mm512_extractf32x4_ps(_sum8, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 9, _mm512_extractf32x4_ps(_sum9, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 10, _mm512_extractf32x4_ps(_suma, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 11, _mm512_extractf32x4_ps(_sumb, 2)); - - _mm_store_ps(outptr0 + out_hstep * 12, _mm512_extractf32x4_ps(_sum0, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 8, _mm512_extractf32x4_ps(_sum8, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 9, _mm512_extractf32x4_ps(_sum9, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 10, _mm512_extractf32x4_ps(_suma, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 11, _mm512_extractf32x4_ps(_sumb, 3)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpa = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum8 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _sum9 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _suma = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _sumb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + 16, _sum4); + _mm512_storeu_ps(outptr0 + 32, _sum8); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 16, _sum5); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 32, _sum9); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _sum6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 32, _suma); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 16, _sum7); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 32, _sumb); outptr0 += 48; } @@ -2537,74 +2467,54 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - __m256 _sum0_0 = _mm256_loadu_ps(pC); - __m256 _sum1_0 = _mm256_loadu_ps(pC + 8); - __m256 _sum2_0 = _mm256_loadu_ps(pC + 8 * 2); - __m256 _sum3_0 = _mm256_loadu_ps(pC + 8 * 3); - __m256 _sum4_0 = _mm256_loadu_ps(pC + 8 * 4); - __m256 _sum5_0 = _mm256_loadu_ps(pC + 8 * 5); - __m256 _sum6_0 = _mm256_loadu_ps(pC + 8 * 6); - __m256 _sum7_0 = _mm256_loadu_ps(pC + 8 * 7); - __m256 _sum0_1 = _mm256_loadu_ps(pC + N * 8); - __m256 _sum1_1 = _mm256_loadu_ps(pC + N * 8 + 8); - __m256 _sum2_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 2); - __m256 _sum3_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 3); - __m256 _sum4_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 4); - __m256 _sum5_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 5); - __m256 _sum6_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 6); - __m256 _sum7_1 = _mm256_loadu_ps(pC + N * 8 + 8 * 7); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum0_0), _sum0_1, 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum1_0), _sum1_1, 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum2_0), _sum2_1, 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum3_0), _sum3_1, 1); - _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum4_0), _sum4_1, 1); - _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum5_0), _sum5_1, 1); - _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum6_0), _sum6_1, 1); - _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum7_0), _sum7_1, 1); + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + N * 8); + __m512 _tmp5 = _mm512_loadu_ps(pC + N * 8 + 16); + __m512 _tmp6 = _mm512_loadu_ps(pC + N * 8 + 32); + __m512 _tmp7 = _mm512_loadu_ps(pC + N * 8 + 48); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + _sum2 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + _sum4 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + _sum6 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); pC += 64; } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 4 * 2); - __m128 _sum3_0 = _mm_loadu_ps(pC + 4 * 3); - __m128 _sum4_0 = _mm_loadu_ps(pC + 4 * 4); - __m128 _sum5_0 = _mm_loadu_ps(pC + 4 * 5); - __m128 _sum6_0 = _mm_loadu_ps(pC + 4 * 6); - __m128 _sum7_0 = _mm_loadu_ps(pC + 4 * 7); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 4 * 2); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 4 * 3); - __m128 _sum4_1 = _mm_loadu_ps(pC + N * 4 + 4 * 4); - __m128 _sum5_1 = _mm_loadu_ps(pC + N * 4 + 4 * 5); - __m128 _sum6_1 = _mm_loadu_ps(pC + N * 4 + 4 * 6); - __m128 _sum7_1 = _mm_loadu_ps(pC + N * 4 + 4 * 7); - __m128 _sum0_2 = _mm_loadu_ps(pC + N * 8); - __m128 _sum1_2 = _mm_loadu_ps(pC + N * 8 + 4); - __m128 _sum2_2 = _mm_loadu_ps(pC + N * 8 + 4 * 2); - __m128 _sum3_2 = _mm_loadu_ps(pC + N * 8 + 4 * 3); - __m128 _sum4_2 = _mm_loadu_ps(pC + N * 8 + 4 * 4); - __m128 _sum5_2 = _mm_loadu_ps(pC + N * 8 + 4 * 5); - __m128 _sum6_2 = _mm_loadu_ps(pC + N * 8 + 4 * 6); - __m128 _sum7_2 = _mm_loadu_ps(pC + N * 8 + 4 * 7); - __m128 _sum0_3 = _mm_loadu_ps(pC + N * 12); - __m128 _sum1_3 = _mm_loadu_ps(pC + N * 12 + 4); - __m128 _sum2_3 = _mm_loadu_ps(pC + N * 12 + 4 * 2); - __m128 _sum3_3 = _mm_loadu_ps(pC + N * 12 + 4 * 3); - __m128 _sum4_3 = _mm_loadu_ps(pC + N * 12 + 4 * 4); - __m128 _sum5_3 = _mm_loadu_ps(pC + N * 12 + 4 * 5); - __m128 _sum6_3 = _mm_loadu_ps(pC + N * 12 + 4 * 6); - __m128 _sum7_3 = _mm_loadu_ps(pC + N * 12 + 4 * 7); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_2), _sum0_3, 1), 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_2), _sum1_3, 1), 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_2), _sum2_3, 1), 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_2), _sum3_3, 1), 1); - _sum4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_0), _sum4_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_2), _sum4_3, 1), 1); - _sum5 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_0), _sum5_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_2), _sum5_3, 1), 1); - _sum6 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_0), _sum6_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_2), _sum6_3, 1), 1); - _sum7 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_0), _sum7_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_2), _sum7_3, 1), 1); + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _mm512_loadu_ps(pC + 16); + _sum2 = _mm512_loadu_ps(pC + N * 4); + _sum3 = _mm512_loadu_ps(pC + N * 4 + 16); + _sum4 = _mm512_loadu_ps(pC + N * 8); + _sum5 = _mm512_loadu_ps(pC + N * 8 + 16); + _sum6 = _mm512_loadu_ps(pC + N * 12); + _sum7 = _mm512_loadu_ps(pC + N * 12 + 16); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum4, _sum6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum4, _sum6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum1, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum5, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum1, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum5, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + pC += 32; } if (out_elempack == 1) @@ -2700,63 +2610,55 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - _mm256_store_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); - _mm256_store_ps(outptr0 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 0)); - _mm256_store_ps(outptr0 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 0)); - _mm256_store_ps(outptr0 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 0)); - _mm256_store_ps(outptr0 + 8 * 4, _mm512_extractf32x8_ps(_sum4, 0)); - _mm256_store_ps(outptr0 + 8 * 5, _mm512_extractf32x8_ps(_sum5, 0)); - _mm256_store_ps(outptr0 + 8 * 6, _mm512_extractf32x8_ps(_sum6, 0)); - _mm256_store_ps(outptr0 + 8 * 7, _mm512_extractf32x8_ps(_sum7, 0)); - - _mm256_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum0, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 4, _mm512_extractf32x8_ps(_sum4, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 5, _mm512_extractf32x8_ps(_sum5, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 6, _mm512_extractf32x8_ps(_sum6, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 7, _mm512_extractf32x8_ps(_sum7, 1)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + _mm512_storeu_ps(outptr0 + 16 * 2, _tmp2); + _mm512_storeu_ps(outptr0 + 16 * 3, _tmp3); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp4); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp5); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 2, _tmp6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 3, _tmp7); outptr0 += 64; } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm512_extractf32x4_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 0)); - _mm_store_ps(outptr0 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 0)); - _mm_store_ps(outptr0 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 0)); - _mm_store_ps(outptr0 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 0)); - _mm_store_ps(outptr0 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 0)); - - _mm_store_ps(outptr0 + out_hstep * 4, _mm512_extractf32x4_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 1)); - - _mm_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x4_ps(_sum0, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 2)); - - _mm_store_ps(outptr0 + out_hstep * 12, _mm512_extractf32x4_ps(_sum0, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 4, _mm512_extractf32x4_ps(_sum4, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 5, _mm512_extractf32x4_ps(_sum5, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 6, _mm512_extractf32x4_ps(_sum6, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 7, _mm512_extractf32x4_ps(_sum7, 3)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + 16, _sum4); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 16, _sum5); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _sum6); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 16, _sum7); outptr0 += 32; } @@ -2840,42 +2742,32 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - __m256 _sum0_0 = _mm256_loadu_ps(pC); - __m256 _sum1_0 = _mm256_loadu_ps(pC + 8); - __m256 _sum2_0 = _mm256_loadu_ps(pC + 16); - __m256 _sum3_0 = _mm256_loadu_ps(pC + 24); - __m256 _sum0_1 = _mm256_loadu_ps(pC + N * 8); - __m256 _sum1_1 = _mm256_loadu_ps(pC + N * 8 + 8); - __m256 _sum2_1 = _mm256_loadu_ps(pC + N * 8 + 16); - __m256 _sum3_1 = _mm256_loadu_ps(pC + N * 8 + 24); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum0_0), _sum0_1, 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum1_0), _sum1_1, 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum2_0), _sum2_1, 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum3_0), _sum3_1, 1); + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + N * 8); + __m512 _tmp3 = _mm512_loadu_ps(pC + N * 8 + 16); + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + _sum2 = _mm512_shuffle_f32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); pC += 32; } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 8); - __m128 _sum3_0 = _mm_loadu_ps(pC + 12); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 8); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 12); - __m128 _sum0_2 = _mm_loadu_ps(pC + N * 8); - __m128 _sum1_2 = _mm_loadu_ps(pC + N * 8 + 4); - __m128 _sum2_2 = _mm_loadu_ps(pC + N * 8 + 8); - __m128 _sum3_2 = _mm_loadu_ps(pC + N * 8 + 12); - __m128 _sum0_3 = _mm_loadu_ps(pC + N * 12); - __m128 _sum1_3 = _mm_loadu_ps(pC + N * 12 + 4); - __m128 _sum2_3 = _mm_loadu_ps(pC + N * 12 + 8); - __m128 _sum3_3 = _mm_loadu_ps(pC + N * 12 + 12); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_2), _sum0_3, 1), 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_2), _sum1_3, 1), 1); - _sum2 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_2), _sum2_3, 1), 1); - _sum3 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_2), _sum3_3, 1), 1); + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _mm512_loadu_ps(pC + N * 4); + _sum2 = _mm512_loadu_ps(pC + N * 8); + _sum3 = _mm512_loadu_ps(pC + N * 12); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); pC += 16; } if (out_elempack == 1) @@ -2954,39 +2846,35 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - _mm256_store_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); - _mm256_store_ps(outptr0 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 0)); - _mm256_store_ps(outptr0 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 0)); - _mm256_store_ps(outptr0 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 0)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); - _mm256_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum0, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 1, _mm512_extractf32x8_ps(_sum1, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 2, _mm512_extractf32x8_ps(_sum2, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8 * 3, _mm512_extractf32x8_ps(_sum3, 1)); + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp3); outptr0 += 32; } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm512_extractf32x4_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 0)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); - _mm_store_ps(outptr0 + out_hstep * 4, _mm512_extractf32x4_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 1)); + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); - _mm_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x4_ps(_sum0, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 2)); - _mm_store_ps(outptr0 + out_hstep * 8 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 2)); - - _mm_store_ps(outptr0 + out_hstep * 12, _mm512_extractf32x4_ps(_sum0, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 1, _mm512_extractf32x4_ps(_sum1, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 2, _mm512_extractf32x4_ps(_sum2, 3)); - _mm_store_ps(outptr0 + out_hstep * 12 + 4 * 3, _mm512_extractf32x4_ps(_sum3, 3)); + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); outptr0 += 16; } @@ -3076,12 +2964,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - __m256 _sum0_0 = _mm256_loadu_ps(pC); - __m256 _sum1_0 = _mm256_loadu_ps(pC + 8); - __m256 _sum0_1 = _mm256_loadu_ps(pC + N * 8); - __m256 _sum1_1 = _mm256_loadu_ps(pC + N * 8 + 8); - _sum0 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum0_0), _sum0_1, 1); - _sum1 = _mm512_insertf32x8(_mm512_castps256_ps512(_sum1_0), _sum1_1, 1); + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + N * 8); + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 2, 3, 2)); pC += 16; } if (out_elempack == 4) @@ -3177,11 +3063,12 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 8) { - _mm256_store_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); - _mm256_store_ps(outptr0 + 8, _mm512_extractf32x8_ps(_sum1, 0)); + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp1); - _mm256_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum0, 1)); - _mm256_store_ps(outptr0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_sum1, 1)); outptr0 += 16; } if (out_elempack == 4) @@ -3501,42 +3388,31 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 4 * 2); - __m128 _sum3_0 = _mm_loadu_ps(pC + 4 * 3); - __m128 _sum4_0 = _mm_loadu_ps(pC + 4 * 4); - __m128 _sum5_0 = _mm_loadu_ps(pC + 4 * 5); - __m128 _sum6_0 = _mm_loadu_ps(pC + 4 * 6); - __m128 _sum7_0 = _mm_loadu_ps(pC + 4 * 7); - __m128 _sum8_0 = _mm_loadu_ps(pC + 4 * 8); - __m128 _sum9_0 = _mm_loadu_ps(pC + 4 * 9); - __m128 _suma_0 = _mm_loadu_ps(pC + 4 * 10); - __m128 _sumb_0 = _mm_loadu_ps(pC + 4 * 11); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 4 * 2); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 4 * 3); - __m128 _sum4_1 = _mm_loadu_ps(pC + N * 4 + 4 * 4); - __m128 _sum5_1 = _mm_loadu_ps(pC + N * 4 + 4 * 5); - __m128 _sum6_1 = _mm_loadu_ps(pC + N * 4 + 4 * 6); - __m128 _sum7_1 = _mm_loadu_ps(pC + N * 4 + 4 * 7); - __m128 _sum8_1 = _mm_loadu_ps(pC + N * 4 + 4 * 8); - __m128 _sum9_1 = _mm_loadu_ps(pC + N * 4 + 4 * 9); - __m128 _suma_1 = _mm_loadu_ps(pC + N * 4 + 4 * 10); - __m128 _sumb_1 = _mm_loadu_ps(pC + N * 4 + 4 * 11); - _sum0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1); - _sum1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1); - _sum2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1); - _sum3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1); - _sum4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_0), _sum4_1, 1); - _sum5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_0), _sum5_1, 1); - _sum6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_0), _sum6_1, 1); - _sum7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_0), _sum7_1, 1); - _sum8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum8_0), _sum8_1, 1); - _sum9 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum9_0), _sum9_1, 1); - _suma = _mm256_insertf128_ps(_mm256_castps128_ps256(_suma_0), _suma_1, 1); - _sumb = _mm256_insertf128_ps(_mm256_castps128_ps256(_sumb_0), _sumb_1, 1); + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + __m256 _tmp2 = _mm256_loadu_ps(pC + 8 * 2); + __m256 _tmp3 = _mm256_loadu_ps(pC + 8 * 3); + __m256 _tmp4 = _mm256_loadu_ps(pC + 8 * 4); + __m256 _tmp5 = _mm256_loadu_ps(pC + 8 * 5); + __m256 _tmp6 = _mm256_loadu_ps(pC + N * 4); + __m256 _tmp7 = _mm256_loadu_ps(pC + N * 4 + 8); + __m256 _tmp8 = _mm256_loadu_ps(pC + N * 4 + 8 * 2); + __m256 _tmp9 = _mm256_loadu_ps(pC + N * 4 + 8 * 3); + __m256 _tmpa = _mm256_loadu_ps(pC + N * 4 + 8 * 4); + __m256 _tmpb = _mm256_loadu_ps(pC + N * 4 + 8 * 5); + + _sum0 = _mm256_permute2f128_ps(_tmp0, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2f128_ps(_tmp0, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + _sum2 = _mm256_permute2f128_ps(_tmp1, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _sum3 = _mm256_permute2f128_ps(_tmp1, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + _sum4 = _mm256_permute2f128_ps(_tmp2, _tmp8, _MM_SHUFFLE(0, 2, 0, 0)); + _sum5 = _mm256_permute2f128_ps(_tmp2, _tmp8, _MM_SHUFFLE(0, 3, 0, 1)); + _sum6 = _mm256_permute2f128_ps(_tmp3, _tmp9, _MM_SHUFFLE(0, 2, 0, 0)); + _sum7 = _mm256_permute2f128_ps(_tmp3, _tmp9, _MM_SHUFFLE(0, 3, 0, 1)); + _sum8 = _mm256_permute2f128_ps(_tmp4, _tmpa, _MM_SHUFFLE(0, 2, 0, 0)); + _sum9 = _mm256_permute2f128_ps(_tmp4, _tmpa, _MM_SHUFFLE(0, 3, 0, 1)); + _suma = _mm256_permute2f128_ps(_tmp5, _tmpb, _MM_SHUFFLE(0, 2, 0, 0)); + _sumb = _mm256_permute2f128_ps(_tmp5, _tmpb, _MM_SHUFFLE(0, 3, 0, 1)); pC += 48; } if (out_elempack == 1) @@ -3648,31 +3524,32 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm256_extractf128_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm256_extractf128_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm256_extractf128_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm256_extractf128_ps(_sum3, 0)); - _mm_store_ps(outptr0 + 4 * 4, _mm256_extractf128_ps(_sum4, 0)); - _mm_store_ps(outptr0 + 4 * 5, _mm256_extractf128_ps(_sum5, 0)); - _mm_store_ps(outptr0 + 4 * 6, _mm256_extractf128_ps(_sum6, 0)); - _mm_store_ps(outptr0 + 4 * 7, _mm256_extractf128_ps(_sum7, 0)); - _mm_store_ps(outptr0 + 4 * 8, _mm256_extractf128_ps(_sum8, 0)); - _mm_store_ps(outptr0 + 4 * 9, _mm256_extractf128_ps(_sum9, 0)); - _mm_store_ps(outptr0 + 4 * 10, _mm256_extractf128_ps(_suma, 0)); - _mm_store_ps(outptr0 + 4 * 11, _mm256_extractf128_ps(_sumb, 0)); - - _mm_store_ps(outptr0 + out_hstep * 4, _mm256_extractf128_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm256_extractf128_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm256_extractf128_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm256_extractf128_ps(_sum3, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 4, _mm256_extractf128_ps(_sum4, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 5, _mm256_extractf128_ps(_sum5, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 6, _mm256_extractf128_ps(_sum6, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 7, _mm256_extractf128_ps(_sum7, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 8, _mm256_extractf128_ps(_sum8, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 9, _mm256_extractf128_ps(_sum9, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 10, _mm256_extractf128_ps(_suma, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 11, _mm256_extractf128_ps(_sumb, 1)); + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_sum8, _sum9, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp5 = _mm256_permute2f128_ps(_suma, _sumb, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp6 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp8 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp9 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmpa = _mm256_permute2f128_ps(_sum8, _sum9, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmpb = _mm256_permute2f128_ps(_suma, _sumb, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + _mm256_storeu_ps(outptr0 + 8 * 2, _tmp2); + _mm256_storeu_ps(outptr0 + 8 * 3, _tmp3); + _mm256_storeu_ps(outptr0 + 8 * 4, _tmp4); + _mm256_storeu_ps(outptr0 + 8 * 5, _tmp5); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp6); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp7); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 2, _tmp8); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 3, _tmp9); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 4, _tmpa); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 5, _tmpb); outptr0 += 48; } @@ -3793,30 +3670,23 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 4 * 2); - __m128 _sum3_0 = _mm_loadu_ps(pC + 4 * 3); - __m128 _sum4_0 = _mm_loadu_ps(pC + 4 * 4); - __m128 _sum5_0 = _mm_loadu_ps(pC + 4 * 5); - __m128 _sum6_0 = _mm_loadu_ps(pC + 4 * 6); - __m128 _sum7_0 = _mm_loadu_ps(pC + 4 * 7); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 4 * 2); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 4 * 3); - __m128 _sum4_1 = _mm_loadu_ps(pC + N * 4 + 4 * 4); - __m128 _sum5_1 = _mm_loadu_ps(pC + N * 4 + 4 * 5); - __m128 _sum6_1 = _mm_loadu_ps(pC + N * 4 + 4 * 6); - __m128 _sum7_1 = _mm_loadu_ps(pC + N * 4 + 4 * 7); - _sum0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1); - _sum1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1); - _sum2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1); - _sum3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1); - _sum4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum4_0), _sum4_1, 1); - _sum5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum5_0), _sum5_1, 1); - _sum6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum6_0), _sum6_1, 1); - _sum7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum7_0), _sum7_1, 1); + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + __m256 _tmp2 = _mm256_loadu_ps(pC + 8 * 2); + __m256 _tmp3 = _mm256_loadu_ps(pC + 8 * 3); + __m256 _tmp4 = _mm256_loadu_ps(pC + N * 4); + __m256 _tmp5 = _mm256_loadu_ps(pC + N * 4 + 8); + __m256 _tmp6 = _mm256_loadu_ps(pC + N * 4 + 8 * 2); + __m256 _tmp7 = _mm256_loadu_ps(pC + N * 4 + 8 * 3); + + _sum0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + _sum2 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _sum3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _sum4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + _sum5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + _sum6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _sum7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); pC += 32; } if (out_elempack == 1) @@ -3895,23 +3765,24 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm256_extractf128_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm256_extractf128_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm256_extractf128_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm256_extractf128_ps(_sum3, 0)); - _mm_store_ps(outptr0 + 4 * 4, _mm256_extractf128_ps(_sum4, 0)); - _mm_store_ps(outptr0 + 4 * 5, _mm256_extractf128_ps(_sum5, 0)); - _mm_store_ps(outptr0 + 4 * 6, _mm256_extractf128_ps(_sum6, 0)); - _mm_store_ps(outptr0 + 4 * 7, _mm256_extractf128_ps(_sum7, 0)); - - _mm_store_ps(outptr0 + out_hstep * 4, _mm256_extractf128_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm256_extractf128_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm256_extractf128_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm256_extractf128_ps(_sum3, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 4, _mm256_extractf128_ps(_sum4, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 5, _mm256_extractf128_ps(_sum5, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 6, _mm256_extractf128_ps(_sum6, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 7, _mm256_extractf128_ps(_sum7, 1)); + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp5 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp6 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + _mm256_storeu_ps(outptr0 + 8 * 2, _tmp2); + _mm256_storeu_ps(outptr0 + 8 * 3, _tmp3); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp4); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp5); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 2, _tmp6); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 3, _tmp7); outptr0 += 32; } @@ -3987,18 +3858,15 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum2_0 = _mm_loadu_ps(pC + 8); - __m128 _sum3_0 = _mm_loadu_ps(pC + 12); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - __m128 _sum2_1 = _mm_loadu_ps(pC + N * 4 + 8); - __m128 _sum3_1 = _mm_loadu_ps(pC + N * 4 + 12); - _sum0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1); - _sum1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1); - _sum2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum2_0), _sum2_1, 1); - _sum3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum3_0), _sum3_1, 1); + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + __m256 _tmp2 = _mm256_loadu_ps(pC + N * 4); + __m256 _tmp3 = _mm256_loadu_ps(pC + N * 4 + 8); + + _sum0 = _mm256_permute2f128_ps(_tmp0, _tmp2, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2f128_ps(_tmp0, _tmp2, _MM_SHUFFLE(0, 3, 0, 1)); + _sum2 = _mm256_permute2f128_ps(_tmp1, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); + _sum3 = _mm256_permute2f128_ps(_tmp1, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); pC += 16; } if (out_elempack == 1) @@ -4067,15 +3935,16 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm256_extractf128_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4 * 1, _mm256_extractf128_ps(_sum1, 0)); - _mm_store_ps(outptr0 + 4 * 2, _mm256_extractf128_ps(_sum2, 0)); - _mm_store_ps(outptr0 + 4 * 3, _mm256_extractf128_ps(_sum3, 0)); + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4, _mm256_extractf128_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 1, _mm256_extractf128_ps(_sum1, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 2, _mm256_extractf128_ps(_sum2, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4 * 3, _mm256_extractf128_ps(_sum3, 1)); + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp2); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp3); outptr0 += 16; } @@ -4147,12 +4016,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - __m128 _sum0_0 = _mm_loadu_ps(pC); - __m128 _sum1_0 = _mm_loadu_ps(pC + 4); - __m128 _sum0_1 = _mm_loadu_ps(pC + N * 4); - __m128 _sum1_1 = _mm_loadu_ps(pC + N * 4 + 4); - _sum0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum0_0), _sum0_1, 1); - _sum1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_sum1_0), _sum1_1, 1); + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + N * 4); + _sum0 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); pC += 8; } if (out_elempack == 1) @@ -4218,11 +4085,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (out_elempack == 4) { - _mm_store_ps(outptr0, _mm256_extractf128_ps(_sum0, 0)); - _mm_store_ps(outptr0 + 4, _mm256_extractf128_ps(_sum1, 0)); + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4, _mm256_extractf128_ps(_sum0, 1)); - _mm_store_ps(outptr0 + out_hstep * 4 + 4, _mm256_extractf128_ps(_sum1, 1)); + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp1); outptr0 += 8; } if (out_elempack == 1) @@ -5898,15 +5765,15 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, #if __AVX512F__ TILE_M = tile_size / 16 * 16; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; TILE_K = tile_size / 16 * 16; #elif __AVX__ TILE_M = tile_size / 8 * 8; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; TILE_K = tile_size / 8 * 8; #elif __SSE2__ TILE_M = tile_size / 4 * 4; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; TILE_K = tile_size / 4 * 4; #else TILE_M = tile_size / 2 * 2; @@ -5933,13 +5800,13 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, #if __AVX512F__ TILE_M = tile_size / 16 * 16; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; #elif __AVX__ TILE_M = tile_size / 8 * 8; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; #elif __SSE2__ TILE_M = tile_size / 4 * 4; - TILE_N = tile_size; + TILE_N = tile_size / 4 * 4; #else TILE_M = tile_size / 2 * 2; TILE_N = tile_size; @@ -5966,7 +5833,15 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; +#if __AVX512F__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __AVX__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif } if (nT > 1)