|
|
|
@@ -152,6 +152,14 @@ FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_a |
|
|
|
vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0); |
|
|
|
vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1); |
|
|
|
} |
|
|
|
|
|
|
|
/*
 * Runs vec_reduce44_mma once per accumulator for out[0] .. out[3].
 *
 * Accumulator i is paired with the temp vectors starting at temp[4 * i];
 * the last argument of the four calls is vy0 + 0, vy0 + 1, vy0 + 8 and
 * vy0 + 9 respectively.  v_alpha is forwarded unchanged to every call.
 */
FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
{
  vec_reduce44_mma(out + 0, temp +  0, v_alpha, &vy0[0]);
  vec_reduce44_mma(out + 1, temp +  4, v_alpha, &vy0[1]);
  vec_reduce44_mma(out + 2, temp +  8, v_alpha, &vy0[8]);
  vec_reduce44_mma(out + 3, temp + 12, v_alpha, &vy0[9]);
}
|
|
|
#endif |
|
|
|
|
|
|
|
FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) |
|
|
|
@@ -341,6 +349,32 @@ FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf |
|
|
|
vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0); |
|
|
|
vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); |
|
|
|
} |
|
|
|
|
|
|
|
/*
 * Fills two local 8-entry vector arrays from ina/inb with two
 * vec_load4_mma calls (entries 0..3, then 4..7) and hands them to the
 * 4x4 multiply helpers:
 *
 *   out + 0  <-  entries 0..1 via vec_mult44a_mma, entries 2..3 via vec_mult44b_mma
 *   out + 2  <-  entries 4..5 via vec_mult44a_mma, entries 6..7 via vec_mult44b_mma
 *
 * The "a" calls use inp + 0 and the "b" calls use inp + 2.
 * NOTE(review): presumably the "a" helper starts the accumulators and the
 * "b" helper accumulates into them - confirm against their definitions.
 */
FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
{
  vec_bf16 ld0[8], ld1[8];

  vec_load4_mma(&ld0[0], &ld1[0], &ina[0], &inb[0]);
  vec_load4_mma(&ld0[4], &ld1[4], &ina[4], &inb[4]);

  vec_mult44a_mma(&out[0], &ld0[0], &ld1[0], &inp[0]);
  vec_mult44a_mma(&out[2], &ld0[4], &ld1[4], &inp[0]);
  vec_mult44b_mma(&out[0], &ld0[2], &ld1[2], &inp[2]);
  vec_mult44b_mma(&out[2], &ld0[6], &ld1[6], &inp[2]);
}
|
|
|
|
|
|
|
/*
 * Same load pattern as the "a" variant of this helper, but every multiply
 * goes through vec_mult44b_mma:
 *
 *   out + 0  <-  entries 0..1 (inp + 0), then entries 2..3 (inp + 2)
 *   out + 2  <-  entries 4..5 (inp + 0), then entries 6..7 (inp + 2)
 *
 * The entries come from two vec_load4_mma calls that fill the local arrays
 * from ina/inb (entries 0..3, then 4..7).
 */
FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
{
  vec_bf16 ld0[8], ld1[8];

  vec_load4_mma(&ld0[0], &ld1[0], &ina[0], &inb[0]);
  vec_load4_mma(&ld0[4], &ld1[4], &ina[4], &inb[4]);

  vec_mult44b_mma(&out[0], &ld0[0], &ld1[0], &inp[0]);
  vec_mult44b_mma(&out[2], &ld0[4], &ld1[4], &inp[0]);
  vec_mult44b_mma(&out[0], &ld0[2], &ld1[2], &inp[2]);
  vec_mult44b_mma(&out[2], &ld0[6], &ld1[6], &inp[2]);
}
|
|
|
#endif |
|
|
|
|
|
|
|
FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) |
|
|
|
@@ -381,49 +415,54 @@ FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0) |
|
|
|
} |
|
|
|
|
|
|
|
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
|
|
|
#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift) |
|
|
|
#else |
|
|
|
#define VEC_SHIFT(data, shift) vec_sld(data, data, shift) |
|
|
|
#endif |
|
|
|
#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift) |
|
|
|
|
|
|
|
typedef __vector unsigned int vec_ui32; |
|
|
|
#define MASK_0 0xf000 |
|
|
|
#define MASK_1 0x0f00 |
|
|
|
#define MASK_2 0x00f0 |
|
|
|
#define MASK_3 0x000f |
|
|
|
#else |
|
|
|
#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift) |
|
|
|
|
|
|
|
static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 }; |
|
|
|
static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 }; |
|
|
|
static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 }; |
|
|
|
static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff }; |
|
|
|
#define MASK_0 0x000f |
|
|
|
#define MASK_1 0x00f0 |
|
|
|
#define MASK_2 0x0f00 |
|
|
|
#define MASK_3 0xf000 |
|
|
|
#endif |
|
|
|
|
|
|
|
FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0) |
|
|
|
FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask) |
|
|
|
{ |
|
|
|
v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0); |
|
|
|
if (mask) { |
|
|
|
v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0)); |
|
|
|
} |
|
|
|
|
|
|
|
v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4); |
|
|
|
v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8); |
|
|
|
v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12); |
|
|
|
v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1); |
|
|
|
v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2); |
|
|
|
v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3); |
|
|
|
} |
|
|
|
|
|
|
|
FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0) |
|
|
|
{ |
|
|
|
v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1); |
|
|
|
vec_make_mult1(v_x0); |
|
|
|
v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1)); |
|
|
|
vec_make_mult1(v_x0, true); |
|
|
|
|
|
|
|
v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12); |
|
|
|
v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4); |
|
|
|
v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8); |
|
|
|
v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3); |
|
|
|
v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1); |
|
|
|
v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2); |
|
|
|
} |
|
|
|
|
|
|
|
FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0) |
|
|
|
{ |
|
|
|
v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2); |
|
|
|
v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3); |
|
|
|
v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2)); |
|
|
|
v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3)); |
|
|
|
vec_make_mult2(v_x0); |
|
|
|
|
|
|
|
v_x0[ 8] = VEC_SHIFT(v_x0[10], 8); |
|
|
|
v_x0[ 9] = VEC_SHIFT(v_x0[10], 12); |
|
|
|
v_x0[11] = VEC_SHIFT(v_x0[10], 4); |
|
|
|
v_x0[12] = VEC_SHIFT(v_x0[15], 4); |
|
|
|
v_x0[13] = VEC_SHIFT(v_x0[15], 8); |
|
|
|
v_x0[14] = VEC_SHIFT(v_x0[15], 12); |
|
|
|
v_x0[ 8] = VEC_SHIFT(v_x0[10], 2); |
|
|
|
v_x0[ 9] = VEC_SHIFT(v_x0[10], 3); |
|
|
|
v_x0[11] = VEC_SHIFT(v_x0[10], 1); |
|
|
|
v_x0[12] = VEC_SHIFT(v_x0[15], 1); |
|
|
|
v_x0[13] = VEC_SHIFT(v_x0[15], 2); |
|
|
|
v_x0[14] = VEC_SHIFT(v_x0[15], 3); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
|