|
|
|
@@ -29,6 +29,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
// Include common macros for BF16 based operations with IA intrinsics |
|
|
|
#include "bf16_common_macros.h" |
|
|
|
|
|
|
|
#undef STORE16_COMPLETE_RESULT |
|
|
|
#undef STORE16_MASK_COMPLETE_RESULT |
|
|
|
#undef STORE8_COMPLETE_RESULT |
|
|
|
#undef STORE8_MASK_COMPLETE_RESULT |
|
|
|
#undef STORE4_COMPLETE_RESULT |
|
|
|
#undef STORE4_MASK_COMPLETE_RESULT |
|
|
|
|
|
|
|
#ifndef ZERO_BETA // Beta is non-zero |
|
|
|
|
|
|
|
#ifndef ONE_BETA // BETA is not ONE |
|
|
|
@@ -231,7 +238,9 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
unsigned char load_mask_value = (((unsigned char)0xff) >> 6); |
|
|
|
@@ -280,7 +289,7 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
} else if (tail_num == 8) { |
|
|
|
__m256 result256 = _mm256_setzero_ps(); |
|
|
|
|
|
|
|
__m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2 |
|
|
|
__m256i matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*2]); // Load 8 rows with n=2 |
|
|
|
__m256i xArray256 = _mm512_castsi512_si256(xArray); |
|
|
|
result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256); |
|
|
|
|
|
|
|
@@ -323,7 +332,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); |
|
|
|
@@ -395,9 +406,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
result256_0 = _mm256_setzero_ps(); |
|
|
|
result256_1 = _mm256_setzero_ps(); |
|
|
|
|
|
|
|
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
|
|
|
|
matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row |
|
|
|
matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row |
|
|
|
@@ -423,8 +434,8 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
if (tail_num > 10) { |
|
|
|
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1))); |
|
|
|
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value); |
|
|
|
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows |
|
|
|
|
|
|
|
matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row |
|
|
|
@@ -439,7 +450,7 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
} else if (tail_num > 5) { |
|
|
|
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2))); |
|
|
|
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value); |
|
|
|
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element |
|
|
|
matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows |
|
|
|
matrixArray256_2 = _mm256_setzero_si256(); |
|
|
|
|
|
|
|
@@ -499,7 +510,9 @@ static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_1 = _mm512_set1_epi32(1); |
|
|
|
@@ -591,7 +604,9 @@ static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512 result_0, result_1; |
|
|
|
@@ -782,7 +797,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_1 = _mm512_set1_epi32(1); |
|
|
|
@@ -866,9 +883,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
|
|
|
|
result256_0 = _mm256_setzero_ps(); |
|
|
|
|
|
|
|
matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element |
|
|
|
matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element |
|
|
|
matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element |
|
|
|
matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element |
|
|
|
matrixArray_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element |
|
|
|
matrixArray_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element |
|
|
|
|
|
|
|
// Process the 0|1 elements |
|
|
|
// Select the 0|1 elements for each row |
|
|
|
@@ -957,7 +974,9 @@ static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_2 = _mm512_set1_epi32(2); |
|
|
|
@@ -1110,7 +1129,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
{ |
|
|
|
BLASLONG tag_m_16x = m & (~15); |
|
|
|
|
|
|
|
__m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
|
|
|
|
if (tag_m_16x > 0) { |
|
|
|
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; |
|
|
|
@@ -1122,7 +1141,9 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_2 = _mm512_set1_epi32(2); |
|
|
|
@@ -1214,7 +1235,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m128 result128, tmp128; |
|
|
|
for (BLASLONG i = tag_m_16x; i < m; i++) { |
|
|
|
result128 = _mm_setzero_ps(); |
|
|
|
matrixArray128 = _mm_loadu_si128(&a[(i)*8]); // Load 1 rows with n=8 |
|
|
|
matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]); // Load 1 rows with n=8 |
|
|
|
result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128); |
|
|
|
tmp128 = _mm_shuffle_ps(result128, result128, 14); |
|
|
|
result128 = _mm_add_ps(result128, tmp128); |
|
|
|
@@ -1258,7 +1279,7 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
|
|
|
|
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7); |
|
|
|
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); |
|
|
|
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0| |
|
|
|
|
|
|
|
if (tag_m_14x > 0) { |
|
|
|
@@ -1271,7 +1292,9 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m256i M256_EPI16_2 = _mm256_set1_epi16(2); |
|
|
|
@@ -1390,7 +1413,7 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
|
|
|
|
unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3); |
|
|
|
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); |
|
|
|
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0| |
|
|
|
|
|
|
|
if (tag_m_12x > 0) { |
|
|
|
@@ -1403,7 +1426,9 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m256i M256_EPI32_1 = _mm256_set1_epi32(1); |
|
|
|
@@ -1522,7 +1547,7 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
|
|
|
|
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); |
|
|
|
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); |
|
|
|
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2|x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0| |
|
|
|
|
|
|
|
if (tag_m_15x > 0) { |
|
|
|
@@ -1535,7 +1560,9 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5; |
|
|
|
@@ -1690,7 +1717,7 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
|
|
|
|
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4); |
|
|
|
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); |
|
|
|
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2| x3|x4|x5|x6|x7| |
|
|
|
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0| |
|
|
|
|
|
|
|
if (tag_m_15x > 0) { |
|
|
|
@@ -1703,7 +1730,9 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5; |
|
|
|
@@ -1873,16 +1902,15 @@ static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_4 = _mm512_set1_epi32(4); |
|
|
|
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); |
|
|
|
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); |
|
|
|
|
|
|
|
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6); |
|
|
|
__mmask32 load_mask = *((__mmask32*) &load_mask_value); |
|
|
|
|
|
|
|
// Prepare X with 2-step interleave way |
|
|
|
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); |
|
|
|
BF16_INTERLEAVE_1x32(xArray) |
|
|
|
@@ -2045,7 +2073,9 @@ static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_4 = _mm512_set1_epi32(4); |
|
|
|
@@ -2207,16 +2237,15 @@ static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_4 = _mm512_set1_epi32(4); |
|
|
|
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); |
|
|
|
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); |
|
|
|
|
|
|
|
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2); |
|
|
|
__mmask32 load_mask = *((__mmask32*) &load_mask_value); |
|
|
|
|
|
|
|
// Prepare X with 2-step interleave way |
|
|
|
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); |
|
|
|
BF16_INTERLEAVE_1x32(xArray) |
|
|
|
@@ -2364,7 +2393,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
{ |
|
|
|
BLASLONG tag_m_16x = m & (~15); |
|
|
|
|
|
|
|
__m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| |
|
|
|
__m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| |
|
|
|
|
|
|
|
if (tag_m_16x > 0) { |
|
|
|
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ |
|
|
|
@@ -2377,7 +2406,9 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i M512_EPI32_4 = _mm512_set1_epi32(4); |
|
|
|
@@ -2484,7 +2515,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x |
|
|
|
__m128 accum128, tmp128; |
|
|
|
for (BLASLONG i = tag_m_16x; i < m; i++) { |
|
|
|
accum256 = _mm256_setzero_ps(); |
|
|
|
matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 rows with n=16 |
|
|
|
matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 rows with n=16 |
|
|
|
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256); |
|
|
|
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1)); |
|
|
|
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e); |
|
|
|
@@ -2535,7 +2566,9 @@ static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ |
|
|
|
@@ -2647,8 +2680,6 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b |
|
|
|
BLASLONG tag_n_32x = n & (~31); |
|
|
|
BLASLONG tag_n_128x = n & (~127); |
|
|
|
|
|
|
|
__m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \ |
|
|
|
accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15; |
|
|
|
__m512 accum512_bridge[8]; |
|
|
|
__m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3; |
|
|
|
__m256 accum256_0; |
|
|
|
@@ -2658,7 +2689,9 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; |
|
|
|
@@ -2825,7 +2858,9 @@ static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf |
|
|
|
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_set1_ps(beta); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7; |
|
|
|
@@ -2961,7 +2996,9 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 |
|
|
|
__m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha)); |
|
|
|
#endif |
|
|
|
#ifndef ZERO_BETA |
|
|
|
#ifndef ONE_BETA |
|
|
|
__m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta)); |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
__m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \ |
|
|
|
@@ -3012,7 +3049,7 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 |
|
|
|
__m128 accum128, tmp128; |
|
|
|
for (BLASLONG i = tag_m_8x; i < m; i++) { |
|
|
|
accum256_0 = _mm256_setzero_ps(); |
|
|
|
matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 rows with n=16 |
|
|
|
matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]); // Load 1 rows with n=16 |
|
|
|
accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256); |
|
|
|
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1)); |
|
|
|
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e); |
|
|
|
|