|
|
|
@@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
#include "common.h" |
|
|
|
#include <arm_neon.h> |
|
|
|
|
|
|
|
#if (defined(__GNUC__) && __GNUC__ >= 13) |
|
|
|
#define BF16_TO_FP32(bf16) ((float)(bf16)) |
|
|
|
#else |
|
|
|
/*
 * Widen a single bfloat16 value to float (fallback used when the compiler
 * branch above, which relies on a plain (float) cast, is not selected).
 *
 * The 16 bits of the bfloat16 are placed in the upper half of a 32-bit word,
 * the lower 16 bits are left as zero, and that word is reinterpreted as a
 * float.
 *
 * The reinterpretation goes through unions rather than pointer casts:
 * dereferencing a bfloat16_t object as an integer (or a uint32_t object as
 * a float) through a cast pointer violates C's effective-type rules, whereas
 * reading a different union member is a defined reinterpretation of the
 * stored bytes. uint16_t also replaces the non-standard u_int16_t typedef.
 */
static inline float bf16_to_fp32(bfloat16_t bf16) {
  union {
    bfloat16_t value;
    uint16_t bits;
  } src;
  union {
    uint32_t bits;
    float value;
  } dst;

  src.value = bf16;
  /* Shift the bf16 bit pattern into the high half; low half stays zero. */
  dst.bits = (uint32_t)src.bits << 16;
  return dst.value;
}
|
|
|
#define BF16_TO_FP32(bf16) bf16_to_fp32(bf16) |
|
|
|
#endif |
|
|
|
|
|
|
|
static void beta_op(float *x, BLASLONG n, FLOAT beta) { |
|
|
|
if (beta == 0) { |
|
|
|
memset(x, 0, n * sizeof(float)); |
|
|
|
@@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (rest_m) { |
|
|
|
x0 = alpha * BF16_TO_FP32(x_ptr[0]); |
|
|
|
x1 = alpha * BF16_TO_FP32(x_ptr[1]); |
|
|
|
x2 = alpha * BF16_TO_FP32(x_ptr[2]); |
|
|
|
x3 = alpha * BF16_TO_FP32(x_ptr[3]); |
|
|
|
x4 = alpha * BF16_TO_FP32(x_ptr[4]); |
|
|
|
x5 = alpha * BF16_TO_FP32(x_ptr[5]); |
|
|
|
x6 = alpha * BF16_TO_FP32(x_ptr[6]); |
|
|
|
x7 = alpha * BF16_TO_FP32(x_ptr[7]); |
|
|
|
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); |
|
|
|
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); |
|
|
|
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); |
|
|
|
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); |
|
|
|
x4 = alpha * vcvtah_f32_bf16(x_ptr[4]); |
|
|
|
x5 = alpha * vcvtah_f32_bf16(x_ptr[5]); |
|
|
|
x6 = alpha * vcvtah_f32_bf16(x_ptr[6]); |
|
|
|
x7 = alpha * vcvtah_f32_bf16(x_ptr[7]); |
|
|
|
|
|
|
|
for (BLASLONG j = 0; j < rest_m; j++) { |
|
|
|
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); |
|
|
|
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]); |
|
|
|
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]); |
|
|
|
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]); |
|
|
|
y_ptr[j] += x4 * BF16_TO_FP32(a_ptr4[j]); |
|
|
|
y_ptr[j] += x5 * BF16_TO_FP32(a_ptr5[j]); |
|
|
|
y_ptr[j] += x6 * BF16_TO_FP32(a_ptr6[j]); |
|
|
|
y_ptr[j] += x7 * BF16_TO_FP32(a_ptr7[j]); |
|
|
|
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); |
|
|
|
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]); |
|
|
|
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]); |
|
|
|
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]); |
|
|
|
y_ptr[j] += x4 * vcvtah_f32_bf16(a_ptr4[j]); |
|
|
|
y_ptr[j] += x5 * vcvtah_f32_bf16(a_ptr5[j]); |
|
|
|
y_ptr[j] += x6 * vcvtah_f32_bf16(a_ptr6[j]); |
|
|
|
y_ptr[j] += x7 * vcvtah_f32_bf16(a_ptr7[j]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (rest_m) { |
|
|
|
x0 = alpha * BF16_TO_FP32(x_ptr[0]); |
|
|
|
x1 = alpha * BF16_TO_FP32(x_ptr[1]); |
|
|
|
x2 = alpha * BF16_TO_FP32(x_ptr[2]); |
|
|
|
x3 = alpha * BF16_TO_FP32(x_ptr[3]); |
|
|
|
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); |
|
|
|
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); |
|
|
|
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); |
|
|
|
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); |
|
|
|
|
|
|
|
for (BLASLONG j = 0; j < rest_m; j++) { |
|
|
|
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); |
|
|
|
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]); |
|
|
|
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]); |
|
|
|
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]); |
|
|
|
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); |
|
|
|
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]); |
|
|
|
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]); |
|
|
|
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (m & 2) { |
|
|
|
x0 = alpha * (BF16_TO_FP32(x_ptr[0])); |
|
|
|
x1 = alpha * (BF16_TO_FP32(x_ptr[1])); |
|
|
|
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0])); |
|
|
|
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1])); |
|
|
|
|
|
|
|
y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]); |
|
|
|
y_ptr[1] += x0 * BF16_TO_FP32(a_ptr0[1]); |
|
|
|
y_ptr[1] += x1 * BF16_TO_FP32(a_ptr1[1]); |
|
|
|
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); |
|
|
|
y_ptr[1] += x0 * vcvtah_f32_bf16(a_ptr0[1]); |
|
|
|
y_ptr[1] += x1 * vcvtah_f32_bf16(a_ptr1[1]); |
|
|
|
|
|
|
|
a_ptr0 += 2; |
|
|
|
a_ptr1 += 2; |
|
|
|
@@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (m & 1) { |
|
|
|
x0 = alpha * BF16_TO_FP32(x_ptr[0]); |
|
|
|
x1 = alpha * BF16_TO_FP32(x_ptr[1]); |
|
|
|
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); |
|
|
|
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); |
|
|
|
|
|
|
|
y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]); |
|
|
|
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); |
|
|
|
} |
|
|
|
|
|
|
|
x_ptr += 2; |
|
|
|
} |
|
|
|
|
|
|
|
if (n & 1) { |
|
|
|
x0 = BF16_TO_FP32(x_ptr[0]) * alpha; |
|
|
|
x0 = vcvtah_f32_bf16(x_ptr[0]) * alpha; |
|
|
|
y_ptr = y; |
|
|
|
a_ptr0 = a_ptr; |
|
|
|
|
|
|
|
for (j = 0; j < m; j++) { |
|
|
|
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]); |
|
|
|
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
for (j = 0; j < n; j++) { |
|
|
|
x0 = alpha * BF16_TO_FP32(*x_ptr); |
|
|
|
x0 = alpha * vcvtah_f32_bf16(*x_ptr); |
|
|
|
iy = 0; |
|
|
|
for (i = 0; i < m; i++) { |
|
|
|
y[iy] += x0 * BF16_TO_FP32(a_ptr[i]); |
|
|
|
y[iy] += x0 * vcvtah_f32_bf16(a_ptr[i]); |
|
|
|
iy += incy; |
|
|
|
} |
|
|
|
|
|
|
|
|