Browse Source

Merge pull request #5181 from taoye9/change_sbgemn_cast_bf16

replace customize bf16_to_fp32 with arm neon vcvtah_f32_bf16
tags/v0.3.30
Martin Kroeker GitHub 1 year ago
parent
commit
2f778554b8
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 38 additions and 48 deletions
  1. +38
    -48
      kernel/arm64/sbgemv_n_neon.c

+ 38
- 48
kernel/arm64/sbgemv_n_neon.c View File

@@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h" #include "common.h"
#include <arm_neon.h> #include <arm_neon.h>


#if (defined(__GNUC__) && __GNUC__ >= 13)
#define BF16_TO_FP32(bf16) ((float)(bf16))
#else
static inline float bf16_to_fp32(bfloat16_t bf16) {
uint32_t fp32 = (uint32_t)(*((u_int16_t*)(&bf16))) << 16;
return *((float*)&fp32);
}
#define BF16_TO_FP32(bf16) bf16_to_fp32(bf16)
#endif

static void beta_op(float *x, BLASLONG n, FLOAT beta) { static void beta_op(float *x, BLASLONG n, FLOAT beta) {
if (beta == 0) { if (beta == 0) {
memset(x, 0, n * sizeof(float)); memset(x, 0, n * sizeof(float));
@@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
} }


if (rest_m) { if (rest_m) {
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
x2 = alpha * BF16_TO_FP32(x_ptr[2]);
x3 = alpha * BF16_TO_FP32(x_ptr[3]);
x4 = alpha * BF16_TO_FP32(x_ptr[4]);
x5 = alpha * BF16_TO_FP32(x_ptr[5]);
x6 = alpha * BF16_TO_FP32(x_ptr[6]);
x7 = alpha * BF16_TO_FP32(x_ptr[7]);
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
x4 = alpha * vcvtah_f32_bf16(x_ptr[4]);
x5 = alpha * vcvtah_f32_bf16(x_ptr[5]);
x6 = alpha * vcvtah_f32_bf16(x_ptr[6]);
x7 = alpha * vcvtah_f32_bf16(x_ptr[7]);


for (BLASLONG j = 0; j < rest_m; j++) { for (BLASLONG j = 0; j < rest_m; j++) {
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
y_ptr[j] += x4 * BF16_TO_FP32(a_ptr4[j]);
y_ptr[j] += x5 * BF16_TO_FP32(a_ptr5[j]);
y_ptr[j] += x6 * BF16_TO_FP32(a_ptr6[j]);
y_ptr[j] += x7 * BF16_TO_FP32(a_ptr7[j]);
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]);
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]);
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]);
y_ptr[j] += x4 * vcvtah_f32_bf16(a_ptr4[j]);
y_ptr[j] += x5 * vcvtah_f32_bf16(a_ptr5[j]);
y_ptr[j] += x6 * vcvtah_f32_bf16(a_ptr6[j]);
y_ptr[j] += x7 * vcvtah_f32_bf16(a_ptr7[j]);
} }
} }


@@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
} }


if (rest_m) { if (rest_m) {
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
x2 = alpha * BF16_TO_FP32(x_ptr[2]);
x3 = alpha * BF16_TO_FP32(x_ptr[3]);
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);


for (BLASLONG j = 0; j < rest_m; j++) { for (BLASLONG j = 0; j < rest_m; j++) {
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]);
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]);
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]);
} }
} }


@@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
} }


if (m & 2) { if (m & 2) {
x0 = alpha * (BF16_TO_FP32(x_ptr[0]));
x1 = alpha * (BF16_TO_FP32(x_ptr[1]));
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));


y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
y_ptr[1] += x0 * BF16_TO_FP32(a_ptr0[1]);
y_ptr[1] += x1 * BF16_TO_FP32(a_ptr1[1]);
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
y_ptr[1] += x0 * vcvtah_f32_bf16(a_ptr0[1]);
y_ptr[1] += x1 * vcvtah_f32_bf16(a_ptr1[1]);


a_ptr0 += 2; a_ptr0 += 2;
a_ptr1 += 2; a_ptr1 += 2;
@@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
} }


if (m & 1) { if (m & 1) {
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);


y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
} }


x_ptr += 2; x_ptr += 2;
} }


if (n & 1) { if (n & 1) {
x0 = BF16_TO_FP32(x_ptr[0]) * alpha;
x0 = vcvtah_f32_bf16(x_ptr[0]) * alpha;
y_ptr = y; y_ptr = y;
a_ptr0 = a_ptr; a_ptr0 = a_ptr;


for (j = 0; j < m; j++) { for (j = 0; j < m; j++) {
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
} }
} }


@@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
} }


for (j = 0; j < n; j++) { for (j = 0; j < n; j++) {
x0 = alpha * BF16_TO_FP32(*x_ptr);
x0 = alpha * vcvtah_f32_bf16(*x_ptr);
iy = 0; iy = 0;
for (i = 0; i < m; i++) { for (i = 0; i < m; i++) {
y[iy] += x0 * BF16_TO_FP32(a_ptr[i]);
y[iy] += x0 * vcvtah_f32_bf16(a_ptr[i]);
iy += incy; iy += incy;
} }




Loading…
Cancel
Save