| @@ -64,6 +64,28 @@ typedef struct { | |||
| cfg.tile_colsb[7] = n * 4; \ | |||
| _tile_loadconfig(&cfg); | |||
| /* CONFIG for handling k2 and odd tail at the same time | |||
| * tile0 -- A (m x 2k) | |||
| * tile1 -- A (m x 1) | |||
| * tile2 -- B (2k x n) | |||
| * tile3 -- B (1 x n) | |||
| * tile4 -- C (m x n) | |||
| */ | |||
| #define TCONF_TAIL(cfg, m, n, k2) \ | |||
| memset(&cfg, 0, sizeof(tilecfg)); \ | |||
| cfg.palette_id = 1; \ | |||
| cfg.tile_rows[0] = m; \ | |||
| cfg.tile_rows[1] = m; \ | |||
| cfg.tile_rows[2] = k2>>1; \ | |||
| cfg.tile_rows[3] = 1; \ | |||
| cfg.tile_rows[4] = m; \ | |||
| cfg.tile_colsb[0] = k2<<1; \ | |||
| cfg.tile_colsb[1] = 4; \ | |||
| cfg.tile_colsb[2] = n * 4; \ | |||
| cfg.tile_colsb[3] = n * 4; \ | |||
| cfg.tile_colsb[4] = n * 4; \ | |||
| _tile_loadconfig(&cfg); | |||
| #define T_A0 0 | |||
| #define T_A1 1 | |||
| #define T_B0 2 | |||
| @@ -104,6 +126,7 @@ typedef struct { | |||
| #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) | |||
| #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) | |||
| #define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) | |||
| #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) | |||
| @@ -275,86 +298,123 @@ tail_k: | |||
| // process for k < 32 | |||
| BLASLONG k32 = k & ~31; | |||
| BLASLONG k2 = k & ~1; | |||
| int remain_k2 = k2 - k32; | |||
| if (remain_k2 > 0) { | |||
| if (k32 != k) { | |||
| int remain_k2 = k2 - k32; | |||
| m_count = m; | |||
| ptr_a = A; | |||
| ptr_c = C; | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k32; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| lda = remain_k2; | |||
| ldb = 32; | |||
| TCONF(cfg, tail_m, 16, remain_k2); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k32; | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k32; | |||
| ldb = 2 * tail_n; | |||
| TCONF(cfg, tail_m, tail_n, remain_k2); | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k32; | |||
| ptr_a1 = ptr_a + tail_m * k2; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| lda = remain_k2; | |||
| ldb = 32; | |||
| TCONF_TAIL(cfg, tail_m, 16, remain_k2); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k32; | |||
| ptr_b1 = ptr_b + 16 * k2; | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||
| LOAD_B(x, 0); LOAD_B_TAIL(x, 1); | |||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k32; | |||
| ptr_b1 = ptr_b + tail_n * k2; | |||
| ldb = 2 * tail_n; | |||
| TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||
| LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); | |||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (k2 != k) { | |||
| m_count = m; | |||
| ptr_a = A; | |||
| ptr_c = C; | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k2; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| TCONF(cfg, tail_m, 16, 2); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k2; | |||
| LOAD_C(0, 0); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } else if (remain_k2 > 0) { // k%32 = 2x | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| ptr_a0 = ptr_a + tail_m * k32; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| lda = remain_k2; | |||
| ldb = 32; | |||
| TCONF(cfg, tail_m, 16, remain_k2); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k32; | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| ptr_b0 = ptr_b + tail_n * k32; | |||
| ldb = 2 * tail_n; | |||
| TCONF(cfg, tail_m, tail_n, remain_k2); | |||
| LOAD_C(0, 0); | |||
| LOAD_A(0, x); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k2; | |||
| TCONF(cfg, tail_m, tail_n, 2); | |||
| LOAD_C(0, 0); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| MASK_LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| } else { // k%32 = 1 | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k2; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| TCONF(cfg, tail_m, 16, 2); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k2; | |||
| LOAD_C(0, 0); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k2; | |||
| TCONF(cfg, tail_m, tail_n, 2); | |||
| LOAD_C(0, 0); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| MASK_LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||