|
|
|
@@ -82,6 +82,12 @@ typedef struct { |
|
|
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ |
|
|
|
} |
|
|
|
#define MASK_LOAD_A_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ |
|
|
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ |
|
|
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ |
|
|
|
} |
|
|
|
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) |
|
|
|
#define LOAD_B_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ |
|
|
|
@@ -111,7 +117,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA |
|
|
|
A = iB; |
|
|
|
B = iA; |
|
|
|
|
|
|
|
printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc); |
|
|
|
IFLOAT *ptr_a = A, *ptr_b = B; |
|
|
|
IFLOAT *ptr_b0, *ptr_b1; |
|
|
|
IFLOAT *ptr_a0, *ptr_a1; |
|
|
|
@@ -279,5 +284,133 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA |
|
|
|
} |
|
|
|
ptr_a += 32 * k; |
|
|
|
} |
|
|
|
for (; m_count > 0; m_count -= 16) { |
|
|
|
// process at most 16 m at a time |
|
|
|
int tail_m = (m_count > 16) ? 16: m_count; |
|
|
|
__mmask16 amask = (1UL << tail_m) - 1; |
|
|
|
|
|
|
|
ptr_b = B; |
|
|
|
|
|
|
|
ptr_c00 = ptr_c; |
|
|
|
ptr_c01 = ptr_c00 + 16; |
|
|
|
ptr_c += tail_m * ldc; |
|
|
|
n_count = n; |
|
|
|
for (; n_count > 31; n_count -= 32) { |
|
|
|
ptr_a0 = ptr_a; |
|
|
|
|
|
|
|
ptr_b0 = ptr_b; |
|
|
|
ptr_b1 = ptr_b + 16 * k; |
|
|
|
ptr_b += 32 * k; |
|
|
|
|
|
|
|
lda = 32; |
|
|
|
ldb = 32; |
|
|
|
TCONF(cfg, tail_m, 16, 32); |
|
|
|
LOAD_C(0, 0); LOAD_C(0, 1); |
|
|
|
k_count = k; |
|
|
|
for (; k_count > 31; k_count -= 32) { |
|
|
|
LOAD_A(0, x); |
|
|
|
ptr_a0 += tail_m * 32; |
|
|
|
LOAD_B(x, 0); LOAD_B(x, 1); |
|
|
|
ptr_b0 += 16 * 32; |
|
|
|
ptr_b1 += 16 * 32; |
|
|
|
|
|
|
|
MATMUL(0, 0); MATMUL(0, 1); |
|
|
|
} |
|
|
|
STORE_C(0, 0); STORE_C(0, 1); |
|
|
|
if (k_count > 1) { |
|
|
|
/* still have more than 2*k */ |
|
|
|
int remain_k2 = k_count & ~1; |
|
|
|
k_count -= remain_k2; |
|
|
|
lda = remain_k2; |
|
|
|
TCONF(cfg, tail_m, 16, remain_k2); |
|
|
|
/* reconfig will clear all tiles, |
|
|
|
* need to store/load again |
|
|
|
*/ |
|
|
|
LOAD_C(0, 0); LOAD_C(0, 1); |
|
|
|
|
|
|
|
LOAD_A(0, x); |
|
|
|
ptr_a0 += tail_m * remain_k2; |
|
|
|
LOAD_B(x, 0); LOAD_B(x, 1); |
|
|
|
ptr_b0 += 16 * remain_k2; |
|
|
|
ptr_b1 += 16 * remain_k2; |
|
|
|
|
|
|
|
MATMUL(0, 0); MATMUL(0, 1); |
|
|
|
|
|
|
|
STORE_C(0, 0); STORE_C(0, 1); |
|
|
|
} |
|
|
|
if (k_count > 0) { |
|
|
|
/* still have odd tail k, need to transform into 2*k */ |
|
|
|
TCONF(cfg, tail_m, 16, 2); |
|
|
|
|
|
|
|
LOAD_C(0, 0); LOAD_C(0, 1); |
|
|
|
|
|
|
|
MASK_LOAD_A_TAIL(0, x); |
|
|
|
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); |
|
|
|
|
|
|
|
MATMUL(0, 0); MATMUL(0, 1); |
|
|
|
|
|
|
|
STORE_C(0, 0); STORE_C(0, 1); |
|
|
|
} |
|
|
|
ptr_c00 += 32; |
|
|
|
ptr_c01 += 32; |
|
|
|
} |
|
|
|
for (; n_count > 0; n_count -= 16) { |
|
|
|
int tail_n = (n_count > 16) ? 16: n_count; |
|
|
|
__mmask16 bmask = (1UL << tail_n) - 1; |
|
|
|
ptr_a0 = ptr_a; |
|
|
|
|
|
|
|
ptr_b0 = ptr_b; |
|
|
|
ptr_b += tail_n * k; |
|
|
|
|
|
|
|
lda = 32; |
|
|
|
ldb = 2 * tail_n; |
|
|
|
TCONF(cfg, tail_m, tail_n, 32); |
|
|
|
LOAD_C(0, 0); |
|
|
|
k_count = k; |
|
|
|
for (; k_count > 31; k_count -= 32) { |
|
|
|
LOAD_A(0, x); |
|
|
|
ptr_a0 += tail_m * 32; |
|
|
|
LOAD_B(x, 0); |
|
|
|
ptr_b0 += tail_n * 32; |
|
|
|
|
|
|
|
MATMUL(0, 0); |
|
|
|
} |
|
|
|
STORE_C(0, 0); |
|
|
|
if (k_count > 1) { |
|
|
|
/* still have more than 2*k */ |
|
|
|
int remain_k2 = k_count & ~1; |
|
|
|
k_count -= remain_k2; |
|
|
|
lda = remain_k2; |
|
|
|
TCONF(cfg, tail_m, tail_n, remain_k2); |
|
|
|
/* reconfig will clear all tiles, |
|
|
|
* need to store/load again |
|
|
|
*/ |
|
|
|
LOAD_C(0, 0); |
|
|
|
|
|
|
|
LOAD_A(0, x); |
|
|
|
ptr_a0 += tail_m * remain_k2; |
|
|
|
LOAD_B(x, 0); |
|
|
|
ptr_b0 += tail_n * remain_k2; |
|
|
|
|
|
|
|
MATMUL(0, 0); |
|
|
|
|
|
|
|
STORE_C(0, 0); |
|
|
|
} |
|
|
|
if (k_count > 0) { |
|
|
|
/* still have odd tail k, need to transform into 2*k */ |
|
|
|
TCONF(cfg, tail_m, tail_n, 2); |
|
|
|
|
|
|
|
LOAD_C(0, 0); |
|
|
|
|
|
|
|
MASK_LOAD_A_TAIL(0, x); |
|
|
|
MASK_LOAD_B_TAIL(x, 0); |
|
|
|
MATMUL(0, 0); |
|
|
|
|
|
|
|
STORE_C(0, 0); |
|
|
|
} |
|
|
|
ptr_c00 += tail_n; |
|
|
|
} |
|
|
|
ptr_a += tail_m * k; |
|
|
|
} |
|
|
|
return 0; |
|
|
|
} |