|
|
|
@@ -28,7 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
#ifndef SBGEMV_N_MMA_C |
|
|
|
#define SBGEMV_N_MMA_C |
|
|
|
|
|
|
|
#if !defined(_AIX) || defined(__clang__) |
|
|
|
#define USE_BFGEMV_N_MMA |
|
|
|
#endif |
|
|
|
|
|
|
|
#ifdef USE_BFGEMV_N_MMA |
|
|
|
#include "sbgemv_common_power10.c" |
|
|
|
@@ -47,7 +49,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
{ |
|
|
|
IFLOAT *a0; |
|
|
|
__vector_quad temp[2*4]; |
|
|
|
vec_f32 temp0[8*4], vy0[2*4]; |
|
|
|
vec_f32 temp0[8*4]; |
|
|
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; |
|
|
|
|
|
|
|
a0 = ap[0]; |
|
|
|
@@ -55,26 +57,61 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
vec_bf16 *va0 = (vec_bf16 *)a0; |
|
|
|
|
|
|
|
vec_bf16 *x_bf = (vec_bf16 *)(xo); |
|
|
|
vec_bf16 v_x0 = vec_loadN(x_bf, 1); |
|
|
|
|
|
|
|
vec_f32 *v_y = (vec_f32 *)y; |
|
|
|
BLASLONG n8 = n / 8; |
|
|
|
BLASLONG i = 0; |
|
|
|
|
|
|
|
#ifdef USE_MERGE_MMA |
|
|
|
vec_bf16 v_x0[4]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 1); |
|
|
|
vec_f32 vy0[2*4*2]; |
|
|
|
|
|
|
|
vec_make_mult1(v_x0); |
|
|
|
|
|
|
|
for (; i + 8 <= n8; i += 8) { |
|
|
|
vec_load8_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); |
|
|
|
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); |
|
|
|
|
|
|
|
vec_store8_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
|
|
|
|
if (n8 & 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
|
|
|
|
i += 4; |
|
|
|
} |
|
|
|
#else |
|
|
|
vec_bf16 v_x0[1]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 1); |
|
|
|
vec_f32 vy0[2*4]; |
|
|
|
|
|
|
|
for (; i + 4 <= n8; i += 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0); |
|
|
|
vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
for (; i < n8; i++) { |
|
|
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult12_mma(&temp[0], &va0[i], v_x0); |
|
|
|
vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -86,7 +123,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
BLASLONG n3 = n & 3; |
|
|
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); |
|
|
|
|
|
|
|
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0, n); |
|
|
|
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -94,7 +131,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
} else if (n) { |
|
|
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); |
|
|
|
|
|
|
|
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0, n); |
|
|
|
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); |
|
|
|
|
|
|
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -106,7 +143,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
{ |
|
|
|
IFLOAT *a0, *a1; |
|
|
|
__vector_quad temp[2*4]; |
|
|
|
vec_f32 temp0[8*4], vy0[2*4]; |
|
|
|
vec_f32 temp0[8*4]; |
|
|
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; |
|
|
|
|
|
|
|
a0 = ap[0]; |
|
|
|
@@ -116,26 +153,61 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
vec_bf16 *va1 = (vec_bf16 *)a1; |
|
|
|
|
|
|
|
vec_bf16 *x_bf = (vec_bf16 *)(xo); |
|
|
|
vec_bf16 v_x0 = vec_loadN(x_bf, 2); |
|
|
|
|
|
|
|
vec_f32 *v_y = (vec_f32 *)y; |
|
|
|
BLASLONG n8 = n / 8; |
|
|
|
BLASLONG i = 0; |
|
|
|
|
|
|
|
#ifdef USE_MERGE_MMA |
|
|
|
vec_bf16 v_x0[4]; |
|
|
|
vec_f32 vy0[2*4*2]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 2); |
|
|
|
|
|
|
|
vec_make_mult1(v_x0); |
|
|
|
|
|
|
|
for (; i + 8 <= n8; i += 8) { |
|
|
|
vec_load8_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); |
|
|
|
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); |
|
|
|
|
|
|
|
vec_store8_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
|
|
|
|
if (n8 & 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
|
|
|
|
i += 4; |
|
|
|
} |
|
|
|
#else |
|
|
|
vec_bf16 v_x0[1]; |
|
|
|
vec_f32 vy0[2*4]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 2); |
|
|
|
|
|
|
|
for (; i + 4 <= n8; i += 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0); |
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
for (; i < n8; i++) { |
|
|
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0); |
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -147,7 +219,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
BLASLONG n3 = n & 3; |
|
|
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); |
|
|
|
|
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); |
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -155,7 +227,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
} else if (n) { |
|
|
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); |
|
|
|
|
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); |
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
|
|
|
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -167,7 +239,7 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
{ |
|
|
|
IFLOAT *a0, *a1, *a2, *a3; |
|
|
|
__vector_quad temp[2*4]; |
|
|
|
vec_f32 temp0[8*4], vy0[2*4]; |
|
|
|
vec_f32 temp0[8*4]; |
|
|
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; |
|
|
|
|
|
|
|
a0 = ap[0]; |
|
|
|
@@ -181,30 +253,68 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
vec_bf16 *va3 = (vec_bf16 *)a3; |
|
|
|
|
|
|
|
vec_bf16 *x_bf = (vec_bf16 *)(xo); |
|
|
|
vec_bf16 v_x00 = vec_loadN(x_bf, 4); |
|
|
|
|
|
|
|
vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); |
|
|
|
|
|
|
|
vec_f32 *v_y = (vec_f32 *)y; |
|
|
|
BLASLONG n8 = n / 8; |
|
|
|
BLASLONG i = 0; |
|
|
|
|
|
|
|
#ifdef USE_MERGE_MMA |
|
|
|
vec_bf16 v_x0[8]; |
|
|
|
vec_f32 vy0[2*4*2]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 4); |
|
|
|
|
|
|
|
vec_make_mult2(v_x0); |
|
|
|
|
|
|
|
for (; i + 8 <= n8; i += 8) { |
|
|
|
vec_load8_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); |
|
|
|
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); |
|
|
|
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); |
|
|
|
|
|
|
|
vec_store8_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
|
|
|
|
if (n8 & 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
|
|
|
|
i += 4; |
|
|
|
} |
|
|
|
#else |
|
|
|
vec_bf16 v_x0[5]; |
|
|
|
vec_f32 vy0[2*4]; |
|
|
|
v_x0[0] = vec_loadN(x_bf, 4); |
|
|
|
|
|
|
|
v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); |
|
|
|
|
|
|
|
for (; i + 4 <= n8; i += 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); |
|
|
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); |
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); |
|
|
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); |
|
|
|
|
|
|
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
for (; i < n8; i++) { |
|
|
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); |
|
|
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); |
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); |
|
|
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -216,8 +326,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
BLASLONG n3 = n & 3; |
|
|
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); |
|
|
|
|
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); |
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -225,8 +335,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA |
|
|
|
} else if (n) { |
|
|
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); |
|
|
|
|
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); |
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); |
|
|
|
|
|
|
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -239,7 +349,7 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS |
|
|
|
{ |
|
|
|
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; |
|
|
|
__vector_quad temp[2*4]; |
|
|
|
vec_f32 temp0[8*4], vy0[2*4]; |
|
|
|
vec_f32 temp0[8*4]; |
|
|
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; |
|
|
|
|
|
|
|
a0 = ap[0]; |
|
|
|
@@ -261,36 +371,80 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS |
|
|
|
vec_bf16 *vb3 = (vec_bf16 *)b3; |
|
|
|
|
|
|
|
vec_bf16 *x_bf = (vec_bf16 *)(xo); |
|
|
|
vec_bf16 v_x00 = (vec_bf16)vec_load_vec(x_bf); |
|
|
|
|
|
|
|
vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); |
|
|
|
vec_bf16 v_x02 = (vec_bf16)vec_splat((vec_f32)v_x00, 2); |
|
|
|
vec_bf16 v_x03 = (vec_bf16)vec_splat((vec_f32)v_x00, 3); |
|
|
|
|
|
|
|
vec_f32 *v_y = (vec_f32 *)y; |
|
|
|
BLASLONG n8 = n / 8; |
|
|
|
BLASLONG i = 0; |
|
|
|
|
|
|
|
#ifdef USE_MERGE_MMA |
|
|
|
vec_bf16 v_x0[16]; |
|
|
|
vec_f32 vy0[2*4*2]; |
|
|
|
v_x0[0] = (vec_bf16)vec_load_vec(x_bf); |
|
|
|
|
|
|
|
vec_make_mult4(v_x0); |
|
|
|
|
|
|
|
for (; i + 8 <= n8; i += 8) { |
|
|
|
vec_load8_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); |
|
|
|
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); |
|
|
|
vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]); |
|
|
|
vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); |
|
|
|
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); |
|
|
|
|
|
|
|
vec_store8_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
|
|
|
|
if (n8 & 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); |
|
|
|
vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); |
|
|
|
|
|
|
|
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
|
|
|
|
i += 4; |
|
|
|
} |
|
|
|
#else |
|
|
|
vec_bf16 v_x0[13]; |
|
|
|
vec_f32 vy0[2*4]; |
|
|
|
v_x0[0] = (vec_bf16)vec_load_vec(x_bf); |
|
|
|
|
|
|
|
v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); |
|
|
|
v_x0[ 8] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 2); |
|
|
|
v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3); |
|
|
|
|
|
|
|
for (; i + 4 <= n8; i += 4) { |
|
|
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); |
|
|
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); |
|
|
|
vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x02); |
|
|
|
vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x03); |
|
|
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); |
|
|
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); |
|
|
|
vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]); |
|
|
|
vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]); |
|
|
|
|
|
|
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
for (; i < n8; i++) { |
|
|
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]); |
|
|
|
|
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); |
|
|
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); |
|
|
|
vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02); |
|
|
|
vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03); |
|
|
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); |
|
|
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); |
|
|
|
vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]); |
|
|
|
vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -302,10 +456,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS |
|
|
|
BLASLONG n3 = n & 3; |
|
|
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); |
|
|
|
|
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); |
|
|
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); |
|
|
|
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); |
|
|
|
|
|
|
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
@@ -313,10 +467,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS |
|
|
|
} else if (n) { |
|
|
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); |
|
|
|
|
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); |
|
|
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); |
|
|
|
vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); |
|
|
|
|
|
|
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); |
|
|
|
|
|
|
|
|