|
|
|
@@ -0,0 +1,542 @@ |
|
|
|
/*************************************************************************** |
|
|
|
Copyright (c) 2025, The OpenBLAS Project |
|
|
|
All rights reserved. |
|
|
|
|
|
|
|
Redistribution and use in source and binary forms, with or without |
|
|
|
modification, are permitted provided that the following conditions are |
|
|
|
met: |
|
|
|
|
|
|
|
1. Redistributions of source code must retain the above copyright |
|
|
|
notice, this list of conditions and the following disclaimer. |
|
|
|
|
|
|
|
2. Redistributions in binary form must reproduce the above copyright |
|
|
|
notice, this list of conditions and the following disclaimer in |
|
|
|
the documentation and/or other materials provided with the |
|
|
|
distribution. |
|
|
|
3. Neither the name of the OpenBLAS project nor the names of |
|
|
|
its contributors may be used to endorse or promote products |
|
|
|
derived from this software without specific prior written |
|
|
|
permission. |
|
|
|
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
|
|
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
|
|
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
|
|
|
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
|
|
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
|
|
|
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE |
|
|
|
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) |
|
|
|
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT |
|
|
|
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF |
|
|
|
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
*****************************************************************************/ |
|
|
|
|
|
|
|
#include "common.h" |
|
|
|
#include <arm_neon.h> |
|
|
|
|
|
|
|
#if (defined(__GNUC__) && __GNUC__ >= 13)
/* GCC >= 13 understands the bf16 storage type and can widen it directly. */
#define BF16_TO_FP32(bf16) ((float)(bf16))
#else
/*
 * Widen one bf16 value to fp32: a bf16 is exactly the upper 16 bits of the
 * corresponding fp32 bit pattern, so shift it into place and reinterpret.
 *
 * The bit moves go through memcpy instead of pointer casts: the original
 * code used the non-standard `u_int16_t` name and dereferenced
 * incompatible pointer types, which violates strict aliasing (UB).
 * memcpy of 2/4 bytes compiles to the same single move instruction.
 */
static inline float bf16_to_fp32(bfloat16_t bf16) {
  uint16_t bits;
  memcpy(&bits, &bf16, sizeof(bits));
  uint32_t fp32 = (uint32_t)bits << 16;
  float result;
  memcpy(&result, &fp32, sizeof(result));
  return result;
}
#define BF16_TO_FP32(bf16) bf16_to_fp32(bf16)
#endif
|
|
|
|
|
|
|
/*
 * Scale the n-element float vector x in place by beta.
 * beta == 0 zero-fills the vector (BLAS convention: any NaN/Inf already
 * in x is discarded rather than propagated through the multiply).
 */
static void beta_op(float *x, BLASLONG n, FLOAT beta) {
  if (beta == 0) {
    memset(x, 0, n * sizeof(float));
    return;
  }

  BLASLONG idx = 0;

  /* Unrolled NEON path: four 4-lane vectors (16 floats) per iteration. */
  for (; idx + 16 <= n; idx += 16) {
    float32x4_t v0 = vld1q_f32(x + idx);
    float32x4_t v1 = vld1q_f32(x + idx + 4);
    float32x4_t v2 = vld1q_f32(x + idx + 8);
    float32x4_t v3 = vld1q_f32(x + idx + 12);

    vst1q_f32(x + idx, vmulq_n_f32(v0, beta));
    vst1q_f32(x + idx + 4, vmulq_n_f32(v1, beta));
    vst1q_f32(x + idx + 8, vmulq_n_f32(v2, beta));
    vst1q_f32(x + idx + 12, vmulq_n_f32(v3, beta));
  }

  /* Single-vector tail: 4 floats at a time. */
  for (; idx + 4 <= n; idx += 4) {
    vst1q_f32(x + idx, vmulq_n_f32(vld1q_f32(x + idx), beta));
  }

  /* Scalar tail: at most 3 remaining elements. */
  for (; idx < n; idx++) {
    x[idx] *= beta;
  }
}
|
|
|
|
|
|
|
/*
 * BF16 GEMV, no transpose:  y := beta * y + alpha * A * x
 *
 *   m     number of rows of A  (== length of y)
 *   n     number of columns of A  (== length of x)
 *   a     column-major bf16 matrix, lda elements between columns
 *   x     bf16 input vector, stride incx
 *   y     fp32 accumulator/output vector, stride incy
 *
 * Fast path (incx == 1 && incy == 1): alpha is folded into x up front,
 * then 8 columns are processed per outer iteration.  Each 8x8 bf16 tile
 * of A is pair-interleaved with vzip so BFMLALB/BFMLALT can accumulate
 * two columns per instruction into fp32 lanes of y.
 *
 * Fixes vs. the previous revision:
 *  - beta scaling now covers the m elements of y (was n: overran y when
 *    n > m, left a tail unscaled when n < m);
 *  - the 2-column staging buffer is fully zeroed (was 1 element, leaving
 *    lanes 2-3 uninitialized);
 *  - the (n & 2, m & 4) tail no longer reads the never-loaded a2/a3
 *    vectors nor bumps the possibly-uninitialized a_ptr2/a_ptr3;
 *  - the strided fallback assigns 0 when beta == 0 instead of
 *    multiplying (BLAS semantics: do not propagate NaN from y).
 */
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
          bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) {
  BLASLONG i, j;
  bfloat16_t *a_ptr, *x_ptr;
  FLOAT *y_ptr;

  bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
  bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;
  bfloat16x8_t x_vec;
  float32x4_t y1_vec, y2_vec;
  float32x4_t fp32_low, fp32_high;

  float x0, x1, x2, x3, x4, x5, x6, x7;
  bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6,
      *a_ptr7;

  a_ptr = (bfloat16_t *)a;
  x_ptr = (bfloat16_t *)x;

  BLASLONG rest_m = m & 3;

  bfloat16x4_t bf16_zero = vreinterpret_bf16_u16(vdup_n_u16(0));
  bfloat16x8_t bf16_zero_q = vreinterpretq_bf16_u16(vdupq_n_u16(0));

  if (incx == 1 && incy == 1) {
    if (beta != 1) {
      /* y holds one fp32 element per ROW of A, i.e. m elements. */
      beta_op(y, m, beta);
    }

    /* ---- 8 columns of A per iteration ---- */
    for (i = 0; i < n / 8; i++) {
      a_ptr0 = a_ptr;
      a_ptr1 = a_ptr0 + lda;
      a_ptr2 = a_ptr1 + lda;
      a_ptr3 = a_ptr2 + lda;
      a_ptr4 = a_ptr3 + lda;
      a_ptr5 = a_ptr4 + lda;
      a_ptr6 = a_ptr5 + lda;
      a_ptr7 = a_ptr6 + lda;

      a_ptr += 8 * lda;

      y_ptr = y;

      x_vec = vld1q_bf16(x_ptr);

      if (alpha != 1) {
        /* Widen bf16 -> fp32 (a bf16 is the high half of an fp32, so
         * zipping zeros below it yields the fp32 pattern on
         * little-endian), scale by alpha, and narrow back to bf16. */
        fp32_low = vreinterpretq_f32_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
                       vreinterpretq_u16_bf16(x_vec)));
        fp32_high = vreinterpretq_f32_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
                       vreinterpretq_u16_bf16(x_vec)));

        fp32_low = vmulq_n_f32(fp32_low, alpha);
        fp32_high = vmulq_n_f32(fp32_high, alpha);

        x_vec =
            vcombine_bf16(vcvt_bf16_f32(fp32_low), vcvt_bf16_f32(fp32_high));
      }

      /* 8 rows x 8 columns per iteration. */
      for (j = 0; j < m / 8; j++) {
        a0 = vld1q_bf16(a_ptr0);
        a1 = vld1q_bf16(a_ptr1);
        a2 = vld1q_bf16(a_ptr2);
        a3 = vld1q_bf16(a_ptr3);
        a4 = vld1q_bf16(a_ptr4);
        a5 = vld1q_bf16(a_ptr5);
        a6 = vld1q_bf16(a_ptr6);
        a7 = vld1q_bf16(a_ptr7);

        y1_vec = vld1q_f32(y_ptr);
        y2_vec = vld1q_f32(y_ptr + 4);

        /* Interleave column pairs: t0 = {a0[0],a1[0],a0[1],a1[1],...}
         * so BFMLALB takes column a0 (even lanes) and BFMLALT column a1
         * (odd lanes).  t0..t3 cover rows 0-3, t4..t7 rows 4-7. */
        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t1 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
        t2 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a4), vreinterpretq_u16_bf16(a5)));
        t3 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a6), vreinterpretq_u16_bf16(a7)));

        t4 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t5 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
        t6 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a4), vreinterpretq_u16_bf16(a5)));
        t7 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a6), vreinterpretq_u16_bf16(a7)));

        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t0, x_vec, 0);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t0, x_vec, 1);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t1, x_vec, 2);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t1, x_vec, 3);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t2, x_vec, 4);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t2, x_vec, 5);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t3, x_vec, 6);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t3, x_vec, 7);

        y2_vec = vbfmlalbq_laneq_f32(y2_vec, t4, x_vec, 0);
        y2_vec = vbfmlaltq_laneq_f32(y2_vec, t4, x_vec, 1);
        y2_vec = vbfmlalbq_laneq_f32(y2_vec, t5, x_vec, 2);
        y2_vec = vbfmlaltq_laneq_f32(y2_vec, t5, x_vec, 3);
        y2_vec = vbfmlalbq_laneq_f32(y2_vec, t6, x_vec, 4);
        y2_vec = vbfmlaltq_laneq_f32(y2_vec, t6, x_vec, 5);
        y2_vec = vbfmlalbq_laneq_f32(y2_vec, t7, x_vec, 6);
        y2_vec = vbfmlaltq_laneq_f32(y2_vec, t7, x_vec, 7);

        vst1q_f32(y_ptr, y1_vec);
        vst1q_f32(y_ptr + 4, y2_vec);

        a_ptr0 += 8;
        a_ptr1 += 8;
        a_ptr2 += 8;
        a_ptr3 += 8;
        a_ptr4 += 8;
        a_ptr5 += 8;
        a_ptr6 += 8;
        a_ptr7 += 8;

        y_ptr += 8;
      }

      /* 4-row tail: pad each half-loaded column with zeros on top. */
      if (m & 4) {
        bfloat16x4_t a0x4 = vld1_bf16(a_ptr0);
        bfloat16x4_t a1x4 = vld1_bf16(a_ptr1);
        bfloat16x4_t a2x4 = vld1_bf16(a_ptr2);
        bfloat16x4_t a3x4 = vld1_bf16(a_ptr3);
        bfloat16x4_t a4x4 = vld1_bf16(a_ptr4);
        bfloat16x4_t a5x4 = vld1_bf16(a_ptr5);
        bfloat16x4_t a6x4 = vld1_bf16(a_ptr6);
        bfloat16x4_t a7x4 = vld1_bf16(a_ptr7);

        y1_vec = vld1q_f32(y_ptr);

        a0 = vcombine_bf16(a0x4, bf16_zero);
        a1 = vcombine_bf16(a1x4, bf16_zero);
        a2 = vcombine_bf16(a2x4, bf16_zero);
        a3 = vcombine_bf16(a3x4, bf16_zero);
        a4 = vcombine_bf16(a4x4, bf16_zero);
        a5 = vcombine_bf16(a5x4, bf16_zero);
        a6 = vcombine_bf16(a6x4, bf16_zero);
        a7 = vcombine_bf16(a7x4, bf16_zero);

        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t1 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
        t2 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a4), vreinterpretq_u16_bf16(a5)));
        t3 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a6), vreinterpretq_u16_bf16(a7)));

        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t0, x_vec, 0);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t0, x_vec, 1);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t1, x_vec, 2);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t1, x_vec, 3);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t2, x_vec, 4);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t2, x_vec, 5);
        y1_vec = vbfmlalbq_laneq_f32(y1_vec, t3, x_vec, 6);
        y1_vec = vbfmlaltq_laneq_f32(y1_vec, t3, x_vec, 7);

        vst1q_f32(y_ptr, y1_vec);

        a_ptr0 += 4;
        a_ptr1 += 4;
        a_ptr2 += 4;
        a_ptr3 += 4;
        a_ptr4 += 4;
        a_ptr5 += 4;
        a_ptr6 += 4;
        a_ptr7 += 4;

        y_ptr += 4;
      }

      /* Scalar tail: remaining 1-3 rows, all 8 columns.
       * alpha is re-applied from the raw x values here, independent of
       * the scaled x_vec above. */
      if (rest_m) {
        x0 = alpha * BF16_TO_FP32(x_ptr[0]);
        x1 = alpha * BF16_TO_FP32(x_ptr[1]);
        x2 = alpha * BF16_TO_FP32(x_ptr[2]);
        x3 = alpha * BF16_TO_FP32(x_ptr[3]);
        x4 = alpha * BF16_TO_FP32(x_ptr[4]);
        x5 = alpha * BF16_TO_FP32(x_ptr[5]);
        x6 = alpha * BF16_TO_FP32(x_ptr[6]);
        x7 = alpha * BF16_TO_FP32(x_ptr[7]);

        for (BLASLONG j = 0; j < rest_m; j++) {
          y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
          y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
          y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
          y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
          y_ptr[j] += x4 * BF16_TO_FP32(a_ptr4[j]);
          y_ptr[j] += x5 * BF16_TO_FP32(a_ptr5[j]);
          y_ptr[j] += x6 * BF16_TO_FP32(a_ptr6[j]);
          y_ptr[j] += x7 * BF16_TO_FP32(a_ptr7[j]);
        }
      }

      x_ptr += 8;
    }

    /* ---- 4-column tail ---- */
    if (n & 4) {
      a_ptr0 = a_ptr;
      a_ptr1 = a_ptr0 + lda;
      a_ptr2 = a_ptr1 + lda;
      a_ptr3 = a_ptr2 + lda;

      a_ptr += 4 * lda;

      bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
      if (alpha != 1) {
        /* Widen the 4 bf16 x values, scale by alpha, narrow back. */
        x_vec = vcombine_bf16(x_vecx4, bf16_zero);
        fp32_low = vreinterpretq_f32_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
                       vreinterpretq_u16_bf16(x_vec)));
        fp32_low = vmulq_n_f32(fp32_low, alpha);
        x_vecx4 = vcvt_bf16_f32(fp32_low);
      }

      y_ptr = y;
      for (j = 0; j < m / 8; j++) {
        a0 = vld1q_bf16(a_ptr0);
        a1 = vld1q_bf16(a_ptr1);
        a2 = vld1q_bf16(a_ptr2);
        a3 = vld1q_bf16(a_ptr3);

        y1_vec = vld1q_f32(y_ptr);
        y2_vec = vld1q_f32(y_ptr + 4);

        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t1 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
        t4 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t5 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));

        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
        y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);

        y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
        y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
        y2_vec = vbfmlalbq_lane_f32(y2_vec, t5, x_vecx4, 2);
        y2_vec = vbfmlaltq_lane_f32(y2_vec, t5, x_vecx4, 3);

        vst1q_f32(y_ptr, y1_vec);
        vst1q_f32(y_ptr + 4, y2_vec);

        a_ptr0 += 8;
        a_ptr1 += 8;
        a_ptr2 += 8;
        a_ptr3 += 8;

        y_ptr += 8;
      }

      if (m & 4) {
        bfloat16x4_t a0x4 = vld1_bf16(a_ptr0);
        bfloat16x4_t a1x4 = vld1_bf16(a_ptr1);
        bfloat16x4_t a2x4 = vld1_bf16(a_ptr2);
        bfloat16x4_t a3x4 = vld1_bf16(a_ptr3);

        y1_vec = vld1q_f32(y_ptr);

        a0 = vcombine_bf16(a0x4, bf16_zero);
        a1 = vcombine_bf16(a1x4, bf16_zero);
        a2 = vcombine_bf16(a2x4, bf16_zero);
        a3 = vcombine_bf16(a3x4, bf16_zero);

        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t1 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));

        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
        y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);

        vst1q_f32(y_ptr, y1_vec);

        a_ptr0 += 4;
        a_ptr1 += 4;
        a_ptr2 += 4;
        a_ptr3 += 4;

        y_ptr += 4;
      }

      if (rest_m) {
        x0 = alpha * BF16_TO_FP32(x_ptr[0]);
        x1 = alpha * BF16_TO_FP32(x_ptr[1]);
        x2 = alpha * BF16_TO_FP32(x_ptr[2]);
        x3 = alpha * BF16_TO_FP32(x_ptr[3]);

        for (BLASLONG j = 0; j < rest_m; j++) {
          y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
          y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
          y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
          y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
        }
      }

      x_ptr += 4;
    }

    /* ---- 2-column tail ---- */
    if (n & 2) {
      a_ptr0 = a_ptr;
      a_ptr1 = a_ptr0 + lda;

      a_ptr += 2 * lda;

      /* Stage the 2 remaining x values in a zero-padded 4-lane buffer
       * so lanes 2-3 contribute exactly 0.  The whole buffer must be
       * cleared (previously only the first element was, leaving lanes
       * 2-3 uninitialized). */
      bfloat16_t tmp_buffer[4];
      memset((void *)tmp_buffer, 0, sizeof(tmp_buffer));

      tmp_buffer[0] = x_ptr[0];
      tmp_buffer[1] = x_ptr[1];

      bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer);
      if (alpha != 1) {
        x_vec = vcombine_bf16(x_vecx4, bf16_zero);
        fp32_low = vreinterpretq_f32_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
                       vreinterpretq_u16_bf16(x_vec)));
        fp32_low = vmulq_n_f32(fp32_low, alpha);
        x_vecx4 = vcvt_bf16_f32(fp32_low);
      }

      y_ptr = y;
      for (j = 0; j < m / 8; j++) {
        a0 = vld1q_bf16(a_ptr0);
        a1 = vld1q_bf16(a_ptr1);

        y1_vec = vld1q_f32(y_ptr);
        y2_vec = vld1q_f32(y_ptr + 4);

        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
        t4 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);

        y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
        y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);

        vst1q_f32(y_ptr, y1_vec);
        vst1q_f32(y_ptr + 4, y2_vec);

        a_ptr0 += 8;
        a_ptr1 += 8;

        y_ptr += 8;
      }

      if (m & 4) {
        bfloat16x4_t a0x4 = vld1_bf16(a_ptr0);
        bfloat16x4_t a1x4 = vld1_bf16(a_ptr1);

        y1_vec = vld1q_f32(y_ptr);

        a0 = vcombine_bf16(a0x4, bf16_zero);
        a1 = vcombine_bf16(a1x4, bf16_zero);

        /* Only 2 columns remain here: use t0 / lanes 0-1 exclusively.
         * (Previously this also built t1 from the never-loaded a2/a3
         * vectors and multiplied by lanes 2-3 — uninitialized data and,
         * when n < 4, arithmetic on uninitialized pointers.) */
        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);

        vst1q_f32(y_ptr, y1_vec);

        a_ptr0 += 4;
        a_ptr1 += 4;

        y_ptr += 4;
      }

      if (m & 2) {
        x0 = alpha * BF16_TO_FP32(x_ptr[0]);
        x1 = alpha * BF16_TO_FP32(x_ptr[1]);

        y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
        y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
        y_ptr[1] += x0 * BF16_TO_FP32(a_ptr0[1]);
        y_ptr[1] += x1 * BF16_TO_FP32(a_ptr1[1]);

        a_ptr0 += 2;
        a_ptr1 += 2;

        y_ptr += 2;
      }

      if (m & 1) {
        x0 = alpha * BF16_TO_FP32(x_ptr[0]);
        x1 = alpha * BF16_TO_FP32(x_ptr[1]);

        y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
        y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
      }

      x_ptr += 2;
    }

    /* ---- last single column ---- */
    if (n & 1) {
      x0 = BF16_TO_FP32(x_ptr[0]) * alpha;
      y_ptr = y;
      a_ptr0 = a_ptr;

      for (j = 0; j < m; j++) {
        y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
      }
    }

    return (0);
  }

  /* ---- generic strided fallback (incx != 1 or incy != 1) ---- */
  BLASLONG iy = 0;
  for (i = 0; i < m; i++) {
    /* beta == 0 must overwrite, not multiply, so NaN/Inf in y vanish
     * (matches beta_op's memset semantics on the fast path). */
    y[iy] = (beta == 0) ? 0.0f : y[iy] * beta;
    iy += incy;
  }

  for (j = 0; j < n; j++) {
    x0 = alpha * BF16_TO_FP32(*x_ptr);
    iy = 0;
    for (i = 0; i < m; i++) {
      y[iy] += x0 * BF16_TO_FP32(a_ptr[i]);
      iy += incy;
    }

    a_ptr += lda;
    x_ptr += incx;
  }

  return (0);
}