| @@ -64,6 +64,28 @@ typedef struct { | |||||
| cfg.tile_colsb[7] = n * 4; \ | cfg.tile_colsb[7] = n * 4; \ | ||||
| _tile_loadconfig(&cfg); | _tile_loadconfig(&cfg); | ||||
| /* CONFIG for handling k2 and odd tail at the same time | |||||
| * tile0 -- A (m x 2k) | |||||
| * tile1 -- A (m x 1) | |||||
| * tile2 -- B (2k x n) | |||||
| * tile3 -- B (1 x n) | |||||
| * tile4 -- C (m x n) | |||||
| */ | |||||
| #define TCONF_TAIL(cfg, m, n, k2) \ | |||||
| memset(&cfg, 0, sizeof(tilecfg)); \ | |||||
| cfg.palette_id = 1; \ | |||||
| cfg.tile_rows[0] = m; \ | |||||
| cfg.tile_rows[1] = m; \ | |||||
| cfg.tile_rows[2] = k2>>1; \ | |||||
| cfg.tile_rows[3] = 1; \ | |||||
| cfg.tile_rows[4] = m; \ | |||||
| cfg.tile_colsb[0] = k2<<1; \ | |||||
| cfg.tile_colsb[1] = 4; \ | |||||
| cfg.tile_colsb[2] = n * 4; \ | |||||
| cfg.tile_colsb[3] = n * 4; \ | |||||
| cfg.tile_colsb[4] = n * 4; \ | |||||
| _tile_loadconfig(&cfg); | |||||
| #define T_A0 0 | #define T_A0 0 | ||||
| #define T_A1 1 | #define T_A1 1 | ||||
| #define T_B0 2 | #define T_B0 2 | ||||
| @@ -104,6 +126,7 @@ typedef struct { | |||||
| #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) | #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) | ||||
| #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) | #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) | ||||
| #define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) | |||||
| #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) | #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) | ||||
| @@ -275,86 +298,123 @@ tail_k: | |||||
| // process for k < 32 | // process for k < 32 | ||||
| BLASLONG k32 = k & ~31; | BLASLONG k32 = k & ~31; | ||||
| BLASLONG k2 = k & ~1; | BLASLONG k2 = k & ~1; | ||||
| int remain_k2 = k2 - k32; | |||||
| if (remain_k2 > 0) { | |||||
| if (k32 != k) { | |||||
| int remain_k2 = k2 - k32; | |||||
| m_count = m; | m_count = m; | ||||
| ptr_a = A; | ptr_a = A; | ||||
| ptr_c = C; | ptr_c = C; | ||||
| for (; m_count > 0; m_count -= 16) { | |||||
| int tail_m = (m_count > 16) ? 16: m_count; | |||||
| __mmask16 amask = (1UL << tail_m) - 1; | |||||
| ptr_a0 = ptr_a + tail_m * k32; | |||||
| ptr_a += tail_m * k; | |||||
| ptr_b = B; | |||||
| ptr_c00 = ptr_c; | |||||
| ptr_c += tail_m * ldc; | |||||
| n_count = n; | |||||
| lda = remain_k2; | |||||
| ldb = 32; | |||||
| TCONF(cfg, tail_m, 16, remain_k2); | |||||
| for (; n_count > 15; n_count -= 16) { | |||||
| ptr_b0 = ptr_b + 16 * k32; | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); | |||||
| LOAD_B(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| ptr_b += 16 * k; | |||||
| ptr_c00 += 16; | |||||
| } | |||||
| if (n_count > 0) { | |||||
| int tail_n = (n_count > 16) ? 16: n_count; | |||||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||||
| ptr_b0 = ptr_b + tail_n * k32; | |||||
| ldb = 2 * tail_n; | |||||
| TCONF(cfg, tail_m, tail_n, remain_k2); | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); | |||||
| LOAD_B(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) | |||||
| for (; m_count > 0; m_count -= 16) { | |||||
| int tail_m = (m_count > 16) ? 16: m_count; | |||||
| __mmask16 amask = (1UL << tail_m) - 1; | |||||
| ptr_a0 = ptr_a + tail_m * k32; | |||||
| ptr_a1 = ptr_a + tail_m * k2; | |||||
| ptr_a += tail_m * k; | |||||
| ptr_b = B; | |||||
| ptr_c00 = ptr_c; | |||||
| ptr_c += tail_m * ldc; | |||||
| n_count = n; | |||||
| lda = remain_k2; | |||||
| ldb = 32; | |||||
| TCONF_TAIL(cfg, tail_m, 16, remain_k2); | |||||
| for (; n_count > 15; n_count -= 16) { | |||||
| ptr_b0 = ptr_b + 16 * k32; | |||||
| ptr_b1 = ptr_b + 16 * k2; | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||||
| LOAD_B(x, 0); LOAD_B_TAIL(x, 1); | |||||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||||
| STORE_C(0, 0); | |||||
| ptr_b += 16 * k; | |||||
| ptr_c00 += 16; | |||||
| } | |||||
| if (n_count > 0) { | |||||
| int tail_n = (n_count > 16) ? 16: n_count; | |||||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||||
| ptr_b0 = ptr_b + tail_n * k32; | |||||
| ptr_b1 = ptr_b + tail_n * k2; | |||||
| ldb = 2 * tail_n; | |||||
| TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||||
| LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); | |||||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||||
| STORE_C(0, 0); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| if (k2 != k) { | |||||
| m_count = m; | |||||
| ptr_a = A; | |||||
| ptr_c = C; | |||||
| for (; m_count > 0; m_count -= 16) { | |||||
| int tail_m = (m_count > 16) ? 16: m_count; | |||||
| __mmask16 amask = (1UL << tail_m) - 1; | |||||
| ptr_a0 = ptr_a + tail_m * k2; | |||||
| ptr_a += tail_m * k; | |||||
| ptr_b = B; | |||||
| ptr_c00 = ptr_c; | |||||
| ptr_c += tail_m * ldc; | |||||
| n_count = n; | |||||
| TCONF(cfg, tail_m, 16, 2); | |||||
| for (; n_count > 15; n_count -= 16) { | |||||
| ptr_b0 = ptr_b + 16 * k2; | |||||
| LOAD_C(0, 0); | |||||
| MASK_LOAD_A_TAIL(0, x); | |||||
| LOAD_B_TAIL(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| ptr_b += 16 * k; | |||||
| ptr_c00 += 16; | |||||
| } else if (remain_k2 > 0) { // k%32 = 2x | |||||
| for (; m_count > 0; m_count -= 16) { | |||||
| int tail_m = (m_count > 16) ? 16: m_count; | |||||
| ptr_a0 = ptr_a + tail_m * k32; | |||||
| ptr_a += tail_m * k; | |||||
| ptr_b = B; | |||||
| ptr_c00 = ptr_c; | |||||
| ptr_c += tail_m * ldc; | |||||
| n_count = n; | |||||
| lda = remain_k2; | |||||
| ldb = 32; | |||||
| TCONF(cfg, tail_m, 16, remain_k2); | |||||
| for (; n_count > 15; n_count -= 16) { | |||||
| ptr_b0 = ptr_b + 16 * k32; | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); | |||||
| LOAD_B(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| ptr_b += 16 * k; | |||||
| ptr_c00 += 16; | |||||
| } | |||||
| if (n_count > 0) { | |||||
| int tail_n = (n_count > 16) ? 16: n_count; | |||||
| ptr_b0 = ptr_b + tail_n * k32; | |||||
| ldb = 2 * tail_n; | |||||
| TCONF(cfg, tail_m, tail_n, remain_k2); | |||||
| LOAD_C(0, 0); | |||||
| LOAD_A(0, x); | |||||
| LOAD_B(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| } | |||||
| } | } | ||||
| if (n_count > 0) { | |||||
| int tail_n = (n_count > 16) ? 16: n_count; | |||||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||||
| ptr_b0 = ptr_b + tail_n * k2; | |||||
| TCONF(cfg, tail_m, tail_n, 2); | |||||
| LOAD_C(0, 0); | |||||
| MASK_LOAD_A_TAIL(0, x); | |||||
| MASK_LOAD_B_TAIL(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| } else { // k%32 = 1 | |||||
| for (; m_count > 0; m_count -= 16) { | |||||
| int tail_m = (m_count > 16) ? 16: m_count; | |||||
| __mmask16 amask = (1UL << tail_m) - 1; | |||||
| ptr_a0 = ptr_a + tail_m * k2; | |||||
| ptr_a += tail_m * k; | |||||
| ptr_b = B; | |||||
| ptr_c00 = ptr_c; | |||||
| ptr_c += tail_m * ldc; | |||||
| n_count = n; | |||||
| TCONF(cfg, tail_m, 16, 2); | |||||
| for (; n_count > 15; n_count -= 16) { | |||||
| ptr_b0 = ptr_b + 16 * k2; | |||||
| LOAD_C(0, 0); | |||||
| MASK_LOAD_A_TAIL(0, x); | |||||
| LOAD_B_TAIL(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| ptr_b += 16 * k; | |||||
| ptr_c00 += 16; | |||||
| } | |||||
| if (n_count > 0) { | |||||
| int tail_n = (n_count > 16) ? 16: n_count; | |||||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||||
| ptr_b0 = ptr_b + tail_n * k2; | |||||
| TCONF(cfg, tail_m, tail_n, 2); | |||||
| LOAD_C(0, 0); | |||||
| MASK_LOAD_A_TAIL(0, x); | |||||
| MASK_LOAD_B_TAIL(x, 0); | |||||
| MATMUL(0, 0); | |||||
| STORE_C(0, 0); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||