From 4c00099ed65af573912065a69d83ce42a9aa0cba Mon Sep 17 00:00:00 2001 From: Ye Tao Date: Wed, 12 Mar 2025 16:20:15 +0000 Subject: [PATCH] replace customize bf16_to_fp32 with arm neon vcvtah_f32_bf16 --- kernel/arm64/sbgemv_n_neon.c | 86 ++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 48 deletions(-) diff --git a/kernel/arm64/sbgemv_n_neon.c b/kernel/arm64/sbgemv_n_neon.c index 9e7ea03c8..489d4d22c 100644 --- a/kernel/arm64/sbgemv_n_neon.c +++ b/kernel/arm64/sbgemv_n_neon.c @@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #include -#if (defined(__GNUC__) && __GNUC__ >= 13) -#define BF16_TO_FP32(bf16) ((float)(bf16)) -#else -static inline float bf16_to_fp32(bfloat16_t bf16) { - uint32_t fp32 = (uint32_t)(*((u_int16_t*)(&bf16))) << 16; - return *((float*)&fp32); -} -#define BF16_TO_FP32(bf16) bf16_to_fp32(bf16) -#endif - static void beta_op(float *x, BLASLONG n, FLOAT beta) { if (beta == 0) { memset(x, 0, n * sizeof(float)); @@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (rest_m) { - x0 = alpha * BF16_TO_FP32(x_ptr[0]); - x1 = alpha * BF16_TO_FP32(x_ptr[1]); - x2 = alpha * BF16_TO_FP32(x_ptr[2]); - x3 = alpha * BF16_TO_FP32(x_ptr[3]); - x4 = alpha * BF16_TO_FP32(x_ptr[4]); - x5 = alpha * BF16_TO_FP32(x_ptr[5]); - x6 = alpha * BF16_TO_FP32(x_ptr[6]); - x7 = alpha * BF16_TO_FP32(x_ptr[7]); + x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); + x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); + x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); + x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); + x4 = alpha * vcvtah_f32_bf16(x_ptr[4]); + x5 = alpha * vcvtah_f32_bf16(x_ptr[5]); + x6 = alpha * vcvtah_f32_bf16(x_ptr[6]); + x7 = alpha * vcvtah_f32_bf16(x_ptr[7]); for (BLASLONG j = 0; j < rest_m; j++) { - y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); - y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]); - y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]); - y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]); - y_ptr[j] += x4 * BF16_TO_FP32(a_ptr4[j]); - y_ptr[j] += x5 * BF16_TO_FP32(a_ptr5[j]); - y_ptr[j] += x6 * BF16_TO_FP32(a_ptr6[j]); - y_ptr[j] += x7 * BF16_TO_FP32(a_ptr7[j]); + y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); + y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]); + y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]); + y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]); + y_ptr[j] += x4 * vcvtah_f32_bf16(a_ptr4[j]); + y_ptr[j] += x5 * vcvtah_f32_bf16(a_ptr5[j]); + y_ptr[j] += x6 * vcvtah_f32_bf16(a_ptr6[j]); + y_ptr[j] += x7 * vcvtah_f32_bf16(a_ptr7[j]); } } @@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (rest_m) { - x0 = alpha * BF16_TO_FP32(x_ptr[0]); - x1 = alpha * BF16_TO_FP32(x_ptr[1]); - x2 = alpha * BF16_TO_FP32(x_ptr[2]); - x3 = alpha * BF16_TO_FP32(x_ptr[3]); + x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); + x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); + x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); + x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); for (BLASLONG j = 0; j < rest_m; j++) { - y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); - y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]); - y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]); - y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]); + y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); + y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]); + y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]); + y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]); } } @@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (m & 2) { - x0 = alpha * (BF16_TO_FP32(x_ptr[0])); - x1 = alpha * (BF16_TO_FP32(x_ptr[1])); + x0 = alpha * (vcvtah_f32_bf16(x_ptr[0])); + x1 = alpha * (vcvtah_f32_bf16(x_ptr[1])); - y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]); - y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]); - y_ptr[1] += x0 * BF16_TO_FP32(a_ptr0[1]); - y_ptr[1] += x1 * BF16_TO_FP32(a_ptr1[1]); + y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); + y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); + y_ptr[1] += x0 * vcvtah_f32_bf16(a_ptr0[1]); + y_ptr[1] += x1 * vcvtah_f32_bf16(a_ptr1[1]); a_ptr0 += 2; a_ptr1 += 2; @@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (m & 1) { - x0 = alpha * BF16_TO_FP32(x_ptr[0]); - x1 = alpha * BF16_TO_FP32(x_ptr[1]); + x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); + x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); - y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]); - y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]); + y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); + y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); } x_ptr += 2; } if (n & 1) { - x0 = BF16_TO_FP32(x_ptr[0]) * alpha; + x0 = vcvtah_f32_bf16(x_ptr[0]) * alpha; y_ptr = y; a_ptr0 = a_ptr; for (j = 0; j < m; j++) { - y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); + y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); } } @@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } for (j = 0; j < n; j++) { - x0 = alpha * BF16_TO_FP32(*x_ptr); + x0 = alpha * vcvtah_f32_bf16(*x_ptr); iy = 0; for (i = 0; i < m; i++) { - y[iy] += x0 * BF16_TO_FP32(a_ptr[i]); + y[iy] += x0 * vcvtah_f32_bf16(a_ptr[i]); iy += incy; }