| @@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in | |||
| FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) | |||
| { | |||
| vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); | |||
| vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); | |||
| vec_load_mult_mma(out, in0, inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); | |||
| } | |||
| FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) | |||
| { | |||
| vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); | |||
| vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); | |||
| vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); | |||
| vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); | |||
| vec_load_mult12a_mma(out, in0, in1, inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); | |||
| } | |||
| @@ -78,6 +76,12 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 * | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); | |||
| } | |||
| FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp) | |||
| { | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); | |||
| } | |||
| FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) | |||
| { | |||
| vec_bf16 in01[2], in11[2]; | |||
| @@ -85,10 +89,8 @@ FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 | |||
| vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); | |||
| vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); | |||
| vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); | |||
| vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); | |||
| } | |||
| FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) | |||
| @@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 | |||
| vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); | |||
| vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); | |||
| vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0); | |||
| vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0); | |||
| vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1); | |||
| vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1); | |||
| } | |||
| FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) | |||
| { | |||
| vec_bf16 in0[4]; | |||
| vec_bf16 in0[2]; | |||
| vec_load_pair2(in0, in); | |||
| vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2)); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]); | |||
| vec_load_mult2_mma(out, in + 0, inp + 0); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]); | |||
| } | |||
| FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) | |||
| @@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 | |||
| vec_load_pair2(in01, in0); | |||
| vec_load_pair2(in11, in1); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); | |||
| vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); | |||
| vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); | |||
| vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2); | |||
| vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3); | |||
| } | |||
| FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp) | |||
| { | |||
| vec_mult2d_mma(out + 0, in01, in11, inp); | |||
| vec_mult2d_mma(out + 2, in21, in31, inp); | |||
| } | |||
| FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) | |||
| @@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 | |||
| vec_load_pair2(in21, in2); | |||
| vec_load_pair2(in31, in3); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]); | |||
| vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0); | |||
| vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1); | |||
| vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2); | |||
| vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3); | |||
| } | |||
| FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) | |||
| @@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i | |||
| FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) | |||
| { | |||
| vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); | |||
| vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); | |||
| vec_loadN_mult_mma(out, in0, inp, n); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); | |||
| } | |||
| FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) | |||
| { | |||
| vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); | |||
| vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); | |||
| vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); | |||
| vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); | |||
| __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); | |||
| vec_loadN_mult12a_mma(out, in0, in1, inp, n); | |||
| __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); | |||
| __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); | |||
| } | |||