sbgemm: add AMX-BF16 based kernel for Sapphire Rapids (tags/v0.3.19)
| @@ -1 +1,14 @@ | |||
# Start from the Cooper Lake kernel settings (pulled in by the include below),
# then assign the SBGEMM entries for this target.
| include $(KERNELDIR)/KERNEL.COOPERLAKE | |||
| SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_spr.c | |||
| SBGEMM_BETA = sgemm_beta_skylakex.c | |||
# Compute kernel: the 16x16 "spr" source added by this change.
| SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c | |||
# Packing routines: the inner (I*) copies reuse the Cooper Lake sources, the
# outer (O*) copies use the "spr" sources.
| SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c | |||
| SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c | |||
| SBGEMMONCOPY = sbgemm_oncopy_16_spr.c | |||
| SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c | |||
# Object file names for the four packing routines.
| SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) | |||
| SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) | |||
| SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||
| SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||
| @@ -0,0 +1,50 @@ | |||
| /*************************************************************************** | |||
| * Copyright (c) 2021, The OpenBLAS Project | |||
| * All rights reserved. | |||
| * Redistribution and use in source and binary forms, with or without | |||
| * modification, are permitted provided that the following conditions are | |||
| * met: | |||
| * 1. Redistributions of source code must retain the above copyright | |||
| * notice, this list of conditions and the following disclaimer. | |||
| * 2. Redistributions in binary form must reproduce the above copyright | |||
| * notice, this list of conditions and the following disclaimer in | |||
| * the documentation and/or other materials provided with the | |||
| * distribution. | |||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||
| * its contributors may be used to endorse or promote products | |||
| * derived from this software without specific prior written permission. | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * *****************************************************************************/ | |||
| #include "common.h" | |||
/* The kernel template is included twice. With ALPHA_ONE defined it emits
 * sbgemm_kernel_spr_alpha_one(); with ALPHA_ONE undefined it emits
 * sbgemm_kernel_spr_alpha(). Both are called by CNAME below. */
| #define ALPHA_ONE | |||
| #include "sbgemm_kernel_16x16_spr_tmpl.c" | |||
| #undef ALPHA_ONE | |||
| #include "sbgemm_kernel_16x16_spr_tmpl.c" | |||
| int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc) | |||
| { | |||
| /* transport to Row Major matrix for AMX requirement */ | |||
| BLASLONG m, n; | |||
| IFLOAT *A, *B; | |||
| m = in; | |||
| n = im; | |||
| A = iB; | |||
| B = iA; | |||
| if (alpha == 1.0f) | |||
| return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc); | |||
| else | |||
| return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc); | |||
| } | |||
| @@ -0,0 +1,530 @@ | |||
| /*************************************************************************** | |||
| * Copyright (c) 2021, The OpenBLAS Project | |||
| * All rights reserved. | |||
| * Redistribution and use in source and binary forms, with or without | |||
| * modification, are permitted provided that the following conditions are | |||
| * met: | |||
| * 1. Redistributions of source code must retain the above copyright | |||
| * notice, this list of conditions and the following disclaimer. | |||
| * 2. Redistributions in binary form must reproduce the above copyright | |||
| * notice, this list of conditions and the following disclaimer in | |||
| * the documentation and/or other materials provided with the | |||
| * distribution. | |||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||
| * its contributors may be used to endorse or promote products | |||
| * derived from this software without specific prior written permission. | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * *****************************************************************************/ | |||
| #include <immintrin.h> | |||
| #include <string.h> | |||
| #include "common.h" | |||
| #ifndef SBGEMM_KERNEL_SPR | |||
| #define SBGEMM_KERNEL_SPR | |||
/*
 * In-memory image of the AMX tile configuration passed to _tile_loadconfig().
 * The image is exactly 64 bytes; the field offsets are given per member.
 * The trailing reserved area is 8 bytes (56-63), so the struct size matches
 * the 64-byte image instead of extending past it.
 */
typedef struct {
	char palette_id;        // byte 0
	char start_row;         // byte 1
	char dummy0[14];        // bytes 2-15 reserved, must be zero
	short tile_colsb[8];    // bytes 16-31: bytes per row of tiles 0-7
	char dummy1[16];        // bytes 32-47 reserved, must be zero
	char tile_rows[8];      // bytes 48-55: number of rows of tiles 0-7
	char dummy2[8];         // bytes 56-63 reserved, must be zero
} tilecfg;
/* Configure all eight tiles for the main multiply and load the configuration.
 * tile0/tile1 -- A (m x 2k)
 * tile2/tile3 -- B (2k x n)
 * tile4-7     -- C (m x n)
 *
 * cfg : a tilecfg object (cleared here before being filled in)
 * m   : rows of the A and C tiles
 * n   : columns of the B and C tiles (4 bytes per column in a tile row)
 * k2  : number of 2-byte elements per A row; B tiles get k2/2 rows
 *
 * Expands to a single statement, so it must be followed by a semicolon.
 */
#define TCONF(cfg, m, n, k2) do {			\
	memset(&(cfg), 0, sizeof(tilecfg));		\
	(cfg).palette_id = 1;				\
	(cfg).tile_rows[0] = (m);			\
	(cfg).tile_rows[1] = (m);			\
	(cfg).tile_rows[2] = (k2) >> 1;			\
	(cfg).tile_rows[3] = (k2) >> 1;			\
	(cfg).tile_rows[4] = (m);			\
	(cfg).tile_rows[5] = (m);			\
	(cfg).tile_rows[6] = (m);			\
	(cfg).tile_rows[7] = (m);			\
	(cfg).tile_colsb[0] = (k2) << 1;		\
	(cfg).tile_colsb[1] = (k2) << 1;		\
	(cfg).tile_colsb[2] = (n) * 4;			\
	(cfg).tile_colsb[3] = (n) * 4;			\
	(cfg).tile_colsb[4] = (n) * 4;			\
	(cfg).tile_colsb[5] = (n) * 4;			\
	(cfg).tile_colsb[6] = (n) * 4;			\
	(cfg).tile_colsb[7] = (n) * 4;			\
	_tile_loadconfig(&(cfg));			\
} while (0)
/* CONFIG for handling k2 and odd tail at the same time
 * tile0 -- A (m x 2k)
 * tile1 -- A (m x 1)   one 4-byte entry per row
 * tile2 -- B (2k x n)
 * tile3 -- B (1 x n)
 * tile4 -- C (m x n)
 * Tiles 5-7 are left cleared (zero rows / zero bytes per row).
 *
 * Expands to a single statement, so it must be followed by a semicolon.
 */
#define TCONF_TAIL(cfg, m, n, k2) do {			\
	memset(&(cfg), 0, sizeof(tilecfg));		\
	(cfg).palette_id = 1;				\
	(cfg).tile_rows[0] = (m);			\
	(cfg).tile_rows[1] = (m);			\
	(cfg).tile_rows[2] = (k2) >> 1;			\
	(cfg).tile_rows[3] = 1;				\
	(cfg).tile_rows[4] = (m);			\
	(cfg).tile_colsb[0] = (k2) << 1;		\
	(cfg).tile_colsb[1] = 4;			\
	(cfg).tile_colsb[2] = (n) * 4;			\
	(cfg).tile_colsb[3] = (n) * 4;			\
	(cfg).tile_colsb[4] = (n) * 4;			\
	_tile_loadconfig(&(cfg));			\
} while (0)
/* Tile register numbers used by the load/multiply/store macros: two A tiles,
 * two B tiles and a 2x2 grid of C accumulators named T_C<M><N>, where MATMUL
 * pairs A tile M with B tile N. */
| #define T_A0 0 | |||
| #define T_A1 1 | |||
| #define T_B0 2 | |||
| #define T_B1 3 | |||
| #define T_C00 4 | |||
| #define T_C01 5 | |||
| #define T_C10 6 | |||
| #define T_C11 7 | |||
// Strides: the last argument of _tile_loadd()/_tile_stored() is a stride in
// bytes, so the element strides used below are scaled by the element size
// (2 for the packed A/B inputs, 4 for the C accumulators).
//
// All macros below rely on locals of the enclosing kernel function
// (ptr_a0/1, ptr_b0/1, ptr_c00..11, lda, ldb, ldc, tail_a, tail_b, amask,
// bmask) and expand to one statement each.

// Load A tile M from ptr_a<M>; lda is in 2-byte elements.
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)

// Odd-k tail of A: read 16 consecutive 2-byte values, zero-extend each one to
// 32 bits in the scratch buffer tail_a (so every tile row holds the value
// followed by a zero) and load the result as tile M with a 4-byte row stride.
//
// NOTE(review): the original stored to `tail_a + 16 * M` but loaded from
// `tail_a + 16 * 2 * M`; those are the same address only if the intrinsic
// macro casts its base argument to `void *` before the addition. Using one
// pointer for both keeps store and load on the same address regardless.
#define LOAD_A_TAIL(M, N) do {					\
	IFLOAT *tail_ptr = tail_a + 16 * M;			\
	__m256i ymm = _mm256_loadu_epi16(ptr_a##M);		\
	__m512i zmm = _mm512_cvtepu16_epi32(ymm);		\
	_mm512_storeu_epi16(tail_ptr, zmm);			\
	_tile_loadd(T_A##M, tail_ptr, 2 * 2);			\
} while (0)

// Same as LOAD_A_TAIL, but only the lanes selected by amask are read from
// memory; the remaining lanes are zero.
#define MASK_LOAD_A_TAIL(M, N) do {				\
	IFLOAT *tail_ptr = tail_a + 16 * M;			\
	__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M);	\
	__m512i zmm = _mm512_cvtepu16_epi32(ymm);		\
	_mm512_storeu_epi16(tail_ptr, zmm);			\
	_tile_loadd(T_A##M, tail_ptr, 2 * 2);			\
} while (0)

// Load B tile N from ptr_b<N>; ldb is in 2-byte elements.
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)

// Odd-k tail of B, staged through tail_b the same way as LOAD_A_TAIL.
#define LOAD_B_TAIL(M, N) do {					\
	IFLOAT *tail_ptr = tail_b + 16 * N;			\
	__m256i ymm = _mm256_loadu_epi16(ptr_b##N);		\
	__m512i zmm = _mm512_cvtepu16_epi32(ymm);		\
	_mm512_storeu_epi16(tail_ptr, zmm);			\
	_tile_loadd(T_B##N, tail_ptr, 2 * 2);			\
} while (0)

// Same as LOAD_B_TAIL, but only the lanes selected by bmask are read.
#define MASK_LOAD_B_TAIL(M, N) do {				\
	IFLOAT *tail_ptr = tail_b + 16 * N;			\
	__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N);	\
	__m512i zmm = _mm512_cvtepu16_epi32(ymm);		\
	_mm512_storeu_epi16(tail_ptr, zmm);			\
	_tile_loadd(T_B##N, tail_ptr, 2 * 2);			\
} while (0)

// C<M><N> += A<M> * B<N>
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
// Tail product: always accumulates into T_C00, whichever A/B tiles are named.
#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
// Store / load accumulator tile C<M><N>; ldc is in 4-byte elements.
#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
#define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
| #endif // end of SBGEMM_KERNEL_SPR | |||
/* Per-instantiation macros (this part is outside the include guard, so it is
 * re-evaluated each time the template is included).
 * With ALPHA_ONE: LOAD_C loads the accumulator tile from memory. */
| #ifdef ALPHA_ONE | |||
| #undef LOAD_C | |||
| #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) | |||
| #else | |||
/* Without ALPHA_ONE: LOAD_C clears the accumulator tile instead. */
| #undef LOAD_C | |||
| #define LOAD_C(M, N) _tile_zero(T_C##M##N) | |||
/* ALPHA_STORE(N): dst<N>[noffset..+15] += alpha_512 * src<N>[noffset..+15].
 * Uses dst<N>, src<N>, noffset and alpha_512 from the enclosing scope and
 * declares zmm_d<N> / zmm_s<N> there. */
| #define ALPHA_STORE(N) \ | |||
| __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \ | |||
| __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \ | |||
| zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ | |||
| _mm512_storeu_ps(dst##N + noffset, zmm_d##N); | |||
/* Masked form of ALPHA_STORE for a partial vector: only the lanes selected by
 * `mask` (from the enclosing scope) are loaded and written back.
 * The name is spelled "APLPHA"; the call sites use the same spelling. */
| #define MASK_APLPHA_STORE(N) \ | |||
| __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \ | |||
| __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \ | |||
| zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ | |||
| _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N); | |||
| #endif // end of ALPHA_ONE | |||
/*
 * AMX SBGEMM worker; this template is compiled twice.
 *   ALPHA_ONE defined   -> sbgemm_kernel_spr_alpha_one(): accumulator tiles are
 *                          loaded from C and stored back to C.
 *   ALPHA_ONE undefined -> sbgemm_kernel_spr_alpha(): products are accumulated
 *                          in a temporary buffer and folded into C at the end
 *                          as C += alpha * tmp.
 * k is processed in two passes: whole 32-deep slabs first, then the k % 32
 * remainder (its even part and, when k is odd, one final element).
 * A and B are read as packed panels; ldc is the row stride of C in elements.
 * Always returns 0.
 */
| #ifdef ALPHA_ONE | |||
| int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) | |||
| #else | |||
| int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) | |||
| #endif | |||
| { | |||
| /* Row Major matrix for AMX requirement */ | |||
| IFLOAT *ptr_a = A, *ptr_b = B; | |||
| IFLOAT *ptr_b0, *ptr_b1; | |||
| IFLOAT *ptr_a0, *ptr_a1; | |||
| FLOAT *ptr_c = C; | |||
| FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; | |||
| BLASLONG lda, ldb; | |||
| BLASLONG m_count = m; | |||
| BLASLONG n_count, k_count; | |||
| #ifndef ALPHA_ONE | |||
/* General-alpha build: results go to a scratch matrix of m rows whose row
 * length cn is n rounded up to a multiple of 32 elements. */
| // make sure each row is 64 bytes aligned | |||
| BLASLONG cn = (n & 31) ? (n & ~31) + 32 : n; | |||
| FLOAT *raw_tmp_c; | |||
/* NOTE(review): the calloc/malloc result below is used without a NULL check. */
| if (k < 32) { | |||
/* With k < 32 the first pass is skipped and the remainder pass accumulates
 * onto whatever the scratch buffer holds, so it must start out as zero. */
| // only need to zero buff in this situation | |||
| raw_tmp_c = (FLOAT *)calloc(1, sizeof(FLOAT) * m * cn + 64); | |||
| } else { | |||
| raw_tmp_c = (FLOAT *)malloc(sizeof(FLOAT) * m * cn + 64); | |||
| } | |||
| // align buf to 64 byte boundary | |||
| FLOAT *tmp_c = (FLOAT *)(((uintptr_t) raw_tmp_c + 63) & ~(uintptr_t)63); | |||
| ptr_c = tmp_c; | |||
/* From here on ldc is the scratch row stride; the caller's stride is kept in
 * ldc_o for the final C update. */
| BLASLONG ldc_o = ldc; | |||
| ldc = cn; | |||
| #endif | |||
/* Scratch rows used by the *_TAIL load macros for the odd-k element. */
| IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); | |||
| IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); | |||
| tilecfg cfg; | |||
/* Pass 1: whole 32-deep k slabs. */
| if (k > 31) { | |||
/* 32 rows at a time, as two 16-row A tiles. */
| for (; m_count > 31; m_count -= 32) { | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c01 = ptr_c00 + 16; | |||
| ptr_c10 = ptr_c + 16 * ldc; | |||
| ptr_c11 = ptr_c10 + 16; | |||
| ptr_c += 32 * ldc; | |||
| n_count = n; | |||
| TCONF(cfg, 16, 16, 32); | |||
/* 32 columns at a time: a 2x2 block of 16x16 accumulator tiles. */
| for (; n_count > 31; n_count -= 32) { | |||
| ptr_a0 = ptr_a; | |||
| ptr_a1 = ptr_a + 16 * k; | |||
| ptr_b0 = ptr_b; | |||
| ptr_b1 = ptr_b + 16 * k; | |||
| ptr_b += 32 * k; | |||
| lda = 32; | |||
| ldb = 32; | |||
| LOAD_C(0, 0); LOAD_C(0, 1); | |||
| LOAD_C(1, 0); LOAD_C(1, 1); | |||
| k_count = k; | |||
| for (; k_count > 31; k_count -= 32) { | |||
| LOAD_A(0, x); LOAD_A(1, x); | |||
| ptr_a0 += 16 * 32; | |||
| ptr_a1 += 16 * 32; | |||
| LOAD_B(x, 0); LOAD_B(x, 1); | |||
| ptr_b0 += 16 * 32; | |||
| ptr_b1 += 16 * 32; | |||
| MATMUL(0, 0); MATMUL(0, 1); | |||
| MATMUL(1, 0); MATMUL(1, 1); | |||
| } | |||
| STORE_C(0, 0); STORE_C(0, 1); | |||
| STORE_C(1, 0); STORE_C(1, 1); | |||
| ptr_c00 += 32; | |||
| ptr_c01 += 32; | |||
| ptr_c10 += 32; | |||
| ptr_c11 += 32; | |||
| } | |||
/* Remaining columns, at most 16 per step; the tiles are reconfigured for the
 * narrower width and only one column of accumulators is used. */
| for (; n_count > 0; n_count -= 16) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| ptr_a0 = ptr_a; | |||
| ptr_a1 = ptr_a + 16 * k; | |||
| ptr_b0 = ptr_b; | |||
| ptr_b += tail_n * k; | |||
| lda = 32; | |||
| ldb = 2 * tail_n; | |||
| TCONF(cfg, 16, tail_n, 32); | |||
| LOAD_C(0, 0); | |||
| LOAD_C(1, 0); | |||
| k_count = k; | |||
| for (; k_count > 31; k_count -= 32) { | |||
| LOAD_A(0, x); LOAD_A(1, x); | |||
| ptr_a0 += 16 * 32; | |||
| ptr_a1 += 16 * 32; | |||
| LOAD_B(x, 0); | |||
| ptr_b0 += tail_n * 32; | |||
| MATMUL(0, 0); | |||
| MATMUL(1, 0); | |||
| } | |||
| STORE_C(0, 0); | |||
| STORE_C(1, 0); | |||
| ptr_c00 += tail_n; | |||
| ptr_c10 += tail_n; | |||
| } | |||
| ptr_a += 32 * k; | |||
| } | |||
/* Remaining rows (fewer than 32), one A tile of up to 16 rows per step. */
| for (; m_count > 0; m_count -= 16) { | |||
| // process at most 16 m at a time | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c01 = ptr_c00 + 16; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| TCONF(cfg, tail_m, 16, 32); | |||
| for (; n_count > 31; n_count -= 32) { | |||
| ptr_a0 = ptr_a; | |||
| ptr_b0 = ptr_b; | |||
| ptr_b1 = ptr_b + 16 * k; | |||
| ptr_b += 32 * k; | |||
| lda = 32; | |||
| ldb = 32; | |||
| LOAD_C(0, 0); LOAD_C(0, 1); | |||
| k_count = k; | |||
| for (; k_count > 31; k_count -= 32) { | |||
| LOAD_A(0, x); | |||
| ptr_a0 += tail_m * 32; | |||
| LOAD_B(x, 0); LOAD_B(x, 1); | |||
| ptr_b0 += 16 * 32; | |||
| ptr_b1 += 16 * 32; | |||
| MATMUL(0, 0); MATMUL(0, 1); | |||
| } | |||
| STORE_C(0, 0); STORE_C(0, 1); | |||
| ptr_c00 += 32; | |||
| ptr_c01 += 32; | |||
| } | |||
| for (; n_count > 0; n_count -= 16) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| ptr_a0 = ptr_a; | |||
| ptr_b0 = ptr_b; | |||
| ptr_b += tail_n * k; | |||
| lda = 32; | |||
| ldb = 2 * tail_n; | |||
| TCONF(cfg, tail_m, tail_n, 32); | |||
| LOAD_C(0, 0); | |||
| k_count = k; | |||
| for (; k_count > 31; k_count -= 32) { | |||
| LOAD_A(0, x); | |||
| ptr_a0 += tail_m * 32; | |||
| LOAD_B(x, 0); | |||
| ptr_b0 += tail_n * 32; | |||
| MATMUL(0, 0); | |||
| } | |||
| STORE_C(0, 0); | |||
| ptr_c00 += tail_n; | |||
| } | |||
| ptr_a += tail_m * k; | |||
| } | |||
| } | |||
/* Pass 2: the k % 32 remainder. k32 is k rounded down to a multiple of 32,
 * k2 is k rounded down to an even number. Results are accumulated onto what
 * pass 1 stored (LOAD_C_F always loads from memory). */
| // process for k < 32 | |||
| BLASLONG k32 = k & ~31; | |||
| BLASLONG k2 = k & ~1; | |||
| if (k32 != k) { | |||
| int remain_k2 = k2 - k32; | |||
| m_count = m; | |||
| ptr_a = A; | |||
| #ifndef ALPHA_ONE | |||
| ptr_c = tmp_c; | |||
| #else | |||
| ptr_c = C; | |||
| #endif | |||
/* Case 1: even part plus one odd element, handled together with the
 * TCONF_TAIL layout (tile0/2 for the even part, tile1/3 for the odd one). */
| if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k32; | |||
| ptr_a1 = ptr_a + tail_m * k2; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| lda = remain_k2; | |||
| ldb = 32; | |||
| if (n_count > 15) { | |||
| TCONF_TAIL(cfg, tail_m, 16, remain_k2); | |||
/* The A tiles do not depend on the column, so they are loaded once here. */
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k32; | |||
| ptr_b1 = ptr_b + 16 * k2; | |||
| LOAD_C_F(0, 0); | |||
| LOAD_B(x, 0); LOAD_B_TAIL(x, 1); | |||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k32; | |||
| ptr_b1 = ptr_b + tail_n * k2; | |||
| ldb = 2 * tail_n; | |||
| TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); | |||
| LOAD_C_F(0, 0); | |||
| LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); | |||
| LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); | |||
| MATMUL(0, 0); MATMUL_TAIL(1, 1); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
/* Case 2: the remainder is even, no odd element. */
| } else if (remain_k2 > 0) { // k%32 = 2x | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| ptr_a0 = ptr_a + tail_m * k32; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| lda = remain_k2; | |||
| ldb = 32; | |||
| if (n_count > 15) { | |||
| TCONF(cfg, tail_m, 16, remain_k2); | |||
| LOAD_A(0, x); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k32; | |||
| LOAD_C_F(0, 0); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| ptr_b0 = ptr_b + tail_n * k32; | |||
| ldb = 2 * tail_n; | |||
| TCONF(cfg, tail_m, tail_n, remain_k2); | |||
| LOAD_C_F(0, 0); | |||
| LOAD_A(0, x); | |||
| LOAD_B(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
/* Case 3: only the single odd element is left; it is loaded through the
 * *_TAIL macros into tiles configured with k2 = 2. */
| } else { // k%32 = 1 | |||
| for (; m_count > 0; m_count -= 16) { | |||
| int tail_m = (m_count > 16) ? 16: m_count; | |||
| __mmask16 amask = (1UL << tail_m) - 1; | |||
| ptr_a0 = ptr_a + tail_m * k2; | |||
| ptr_a += tail_m * k; | |||
| ptr_b = B; | |||
| ptr_c00 = ptr_c; | |||
| ptr_c += tail_m * ldc; | |||
| n_count = n; | |||
| if (n_count > 15) { | |||
| TCONF(cfg, tail_m, 16, 2); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| for (; n_count > 15; n_count -= 16) { | |||
| ptr_b0 = ptr_b + 16 * k2; | |||
| LOAD_C_F(0, 0); | |||
| LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| ptr_b += 16 * k; | |||
| ptr_c00 += 16; | |||
| } | |||
| } | |||
| if (n_count > 0) { | |||
| int tail_n = (n_count > 16) ? 16: n_count; | |||
| __mmask16 bmask = (1UL << tail_n) - 1; | |||
| ptr_b0 = ptr_b + tail_n * k2; | |||
| TCONF(cfg, tail_m, tail_n, 2); | |||
| LOAD_C_F(0, 0); | |||
| MASK_LOAD_A_TAIL(0, x); | |||
| MASK_LOAD_B_TAIL(x, 0); | |||
| MATMUL(0, 0); | |||
| STORE_C(0, 0); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #ifndef ALPHA_ONE | |||
/* General-alpha epilogue: C += alpha * tmp_c, walking the rows in groups of
 * four, then two, then one. Each row is processed 16 floats at a time, with
 * a masked step for the last n % 16 columns. */
| __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); | |||
| BLASLONG n16 = n & ~15; | |||
| BLASLONG noffset; | |||
| FLOAT *src0, *src1, *src2, *src3; | |||
| FLOAT *dst0, *dst1, *dst2, *dst3; | |||
| FLOAT *src = tmp_c; | |||
| FLOAT *dst = C; | |||
| m_count = m; | |||
| for (; m_count > 3; m_count -= 4) { | |||
/* src rows advance by the scratch stride (ldc), dst rows by the caller's
 * stride (ldc_o). */
| src0 = src; | |||
| src1 = src0 + ldc; | |||
| src2 = src1 + ldc; | |||
| src3 = src2 + ldc; | |||
| src += 4 * ldc; | |||
| dst0 = dst; | |||
| dst1 = dst0 + ldc_o; | |||
| dst2 = dst1 + ldc_o; | |||
| dst3 = dst2 + ldc_o; | |||
| dst += 4 * ldc_o; | |||
| noffset = 0; | |||
| for (; noffset < n16; noffset += 16) { | |||
| ALPHA_STORE(0); | |||
| ALPHA_STORE(1); | |||
| ALPHA_STORE(2); | |||
| ALPHA_STORE(3); | |||
| } | |||
| if (noffset < n) { | |||
| __mmask16 mask = (1UL << (n - noffset)) - 1; | |||
| MASK_APLPHA_STORE(0); | |||
| MASK_APLPHA_STORE(1); | |||
| MASK_APLPHA_STORE(2); | |||
| MASK_APLPHA_STORE(3); | |||
| } | |||
| } | |||
| for (; m_count > 1; m_count -= 2) { | |||
| src0 = src; | |||
| src1 = src0 + ldc; | |||
| src += 2 * ldc; | |||
| dst0 = dst; | |||
| dst1 = dst0 + ldc_o; | |||
| dst += 2 * ldc_o; | |||
| noffset = 0; | |||
| for (; noffset < n16; noffset += 16) { | |||
| ALPHA_STORE(0); | |||
| ALPHA_STORE(1); | |||
| } | |||
| if (noffset < n) { | |||
| __mmask16 mask = (1UL << (n - noffset)) - 1; | |||
| MASK_APLPHA_STORE(0); | |||
| MASK_APLPHA_STORE(1); | |||
| } | |||
| } | |||
| for (; m_count > 0; m_count -= 1) { | |||
| src0 = src; | |||
| dst0 = dst; | |||
| noffset = 0; | |||
| for (; noffset < n16; noffset += 16) { | |||
| ALPHA_STORE(0); | |||
| } | |||
| if (noffset < n) { | |||
| __mmask16 mask = (1UL << (n - noffset)) - 1; | |||
| MASK_APLPHA_STORE(0); | |||
| } | |||
| } | |||
/* Release the scratch buffer through the pointer the allocator returned,
 * not through the aligned alias tmp_c. */
| free(raw_tmp_c); | |||
| #endif | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,128 @@ | |||
| /*************************************************************************** | |||
| * Copyright (c) 2021, The OpenBLAS Project | |||
| * All rights reserved. | |||
| * Redistribution and use in source and binary forms, with or without | |||
| * modification, are permitted provided that the following conditions are | |||
| * met: | |||
| * 1. Redistributions of source code must retain the above copyright | |||
| * notice, this list of conditions and the following disclaimer. | |||
| * 2. Redistributions in binary form must reproduce the above copyright | |||
| * notice, this list of conditions and the following disclaimer in | |||
| * the documentation and/or other materials provided with the | |||
| * distribution. | |||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||
| * its contributors may be used to endorse or promote products | |||
| * derived from this software without specific prior written permission. | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * *****************************************************************************/ | |||
| #include <immintrin.h> | |||
| #include "common.h" | |||
/*
 * In-memory image of the AMX tile configuration passed to _tile_loadconfig().
 * The image is exactly 64 bytes; the field offsets are given per member.
 * The trailing reserved area is 8 bytes (56-63), so the struct size matches
 * the 64-byte image instead of extending past it.
 */
typedef struct {
	char palette_id;        // byte 0
	char start_row;         // byte 1
	char dummy0[14];        // bytes 2-15 reserved, must be zero
	short tile_colsb[8];    // bytes 16-31: bytes per row of tiles 0-7
	char dummy1[16];        // bytes 32-47 reserved, must be zero
	char tile_rows[8];      // bytes 48-55: number of rows of tiles 0-7
	char dummy2[8];         // bytes 56-63 reserved, must be zero
} tilecfg;
/* Tile register numbers used by the copy routine. Rows are always 2-byte
 * elements wide (64 bytes = 32 elements, or m * 2 bytes for the remainder):
 *   T_16x32 -- 16 rows x 32 elements (full block)
 *   T_16xm  -- 16 rows x m elements
 *   T_nx32  -- n rows  x 32 elements
 *   T_nxm   -- n rows  x m elements
 */
#define T_16x32 0
#define T_16xm 1
#define T_nx32 2
#define T_nxm 3

/* Clear cfg, describe the tiles above and load the configuration.
 * m and n are the remainder sizes; a remainder tile is only described when
 * its size is non-zero, otherwise it stays cleared.
 * Expands to a single statement, so it must be followed by a semicolon.
 */
#define TCONF(cfg, m, n) do {				\
	memset(&(cfg), 0, sizeof(tilecfg));		\
	(cfg).palette_id = 1;				\
	(cfg).tile_rows[T_16x32] = 16;			\
	(cfg).tile_colsb[T_16x32] = 64;			\
	if (m) {					\
		(cfg).tile_rows[T_16xm] = 16;		\
		(cfg).tile_colsb[T_16xm] = (m) * 2;	\
	}						\
	if (n) {					\
		(cfg).tile_rows[T_nx32] = (n);		\
		(cfg).tile_colsb[T_nx32] = 64;		\
	}						\
	if ((m) && (n)) {				\
		(cfg).tile_rows[T_nxm] = (n);		\
		(cfg).tile_colsb[T_nxm] = (m) * 2;	\
	}						\
	_tile_loadconfig(&(cfg));			\
} while (0)
| int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { | |||
| BLASLONG i, j; | |||
| IFLOAT *aoffset, *boffset; | |||
| IFLOAT *aoffset0; | |||
| aoffset = a; | |||
| boffset = b; | |||
| BLASLONG n16 = n & ~15; | |||
| BLASLONG m32 = m & ~31; | |||
| BLASLONG m2 = m & ~1; | |||
| BLASLONG tail_m = m2 - m32; | |||
| BLASLONG tail_n = n - n16; | |||
| tilecfg cfg; | |||
| TCONF(cfg, tail_m, tail_n); | |||
| for (j = 0; j < n16; j += 16) { | |||
| aoffset0 = aoffset; | |||
| for (i = 0; i < m32; i += 32) { | |||
| _tile_loadd(T_16x32, aoffset0, lda * 2); | |||
| _tile_stored(T_16x32, boffset, 32 * 2); | |||
| aoffset0 += 32; | |||
| boffset += 32 * 16; | |||
| } | |||
| if (i < m2) { | |||
| _tile_loadd(T_16xm, aoffset0, lda * 2); | |||
| _tile_stored(T_16xm, boffset, tail_m * 2); | |||
| aoffset0 += tail_m; | |||
| boffset += tail_m * 16; | |||
| i = m2; | |||
| } | |||
| if (i < m) { | |||
| /* the tail odd k should put alone */ | |||
| for (int ii = 0; ii < 16; ii++) { | |||
| *(boffset + ii) = *(aoffset0 + lda * ii); | |||
| } | |||
| boffset += 16; | |||
| } | |||
| aoffset += 16 * lda; | |||
| } | |||
| if (j < n) { | |||
| aoffset0 = aoffset; | |||
| for (i = 0; i < m32; i += 32) { | |||
| _tile_loadd(T_nx32, aoffset0, lda * 2); | |||
| _tile_stored(T_nx32, boffset, 32 * 2); | |||
| aoffset0 += 32; | |||
| boffset += 32 * tail_n; | |||
| } | |||
| if (i < m2) { | |||
| _tile_loadd(T_nxm, aoffset0, lda * 2); | |||
| _tile_stored(T_nxm, boffset, tail_m * 2); | |||
| aoffset0 += tail_m; | |||
| boffset += tail_m * tail_n; | |||
| } | |||
| if (i < m) { | |||
| for (int ii = 0; ii < tail_n; ii++) { | |||
| *(boffset + ii) = *(aoffset0 + lda * ii); | |||
| } | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,302 @@ | |||
| /*************************************************************************** | |||
| * Copyright (c) 2021, The OpenBLAS Project | |||
| * All rights reserved. | |||
| * Redistribution and use in source and binary forms, with or without | |||
| * modification, are permitted provided that the following conditions are | |||
| * met: | |||
| * 1. Redistributions of source code must retain the above copyright | |||
| * notice, this list of conditions and the following disclaimer. | |||
| * 2. Redistributions in binary form must reproduce the above copyright | |||
| * notice, this list of conditions and the following disclaimer in | |||
| * the documentation and/or other materials provided with the | |||
| * distribution. | |||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||
| * its contributors may be used to endorse or promote products | |||
| * derived from this software without specific prior written permission. | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * *****************************************************************************/ | |||
| #include <immintrin.h> | |||
| #include "common.h" | |||
/* Row loaders. All of them fill the 256-bit registers r0..r7 of the enclosing
 * function from eight consecutive rows starting at aptr (rows are lda
 * elements apart; lda comes from the enclosing scope). */

/* Unmasked: a full 256-bit load from each of the eight rows. */
| #define LOAD_A_8VEC(aptr) \ | |||
| r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \ | |||
| r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \ | |||
| r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \ | |||
| r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \ | |||
| r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \ | |||
| r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \ | |||
| r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \ | |||
| r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); | |||
/* Masked: only the 16-bit lanes selected by nmask (enclosing scope) are read;
 * the other lanes are zeroed. */
| #define MASK_LOAD_A_8VEC(aptr) \ | |||
| r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \ | |||
| r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \ | |||
| r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \ | |||
| r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \ | |||
| r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \ | |||
| r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \ | |||
| r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \ | |||
| r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); | |||
/* Partial row count: loads only the first `cond` rows (1..8). The case labels
 * fall through on purpose, from row cond-1 down to row 0; registers for rows
 * at or above cond are left untouched, and any other cond loads nothing. */
| #define SWITCH_LOAD_A_8VEC(aptr, cond) \ | |||
| switch((cond)) { \ | |||
| case 8: r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); \ | |||
| case 7: r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \ | |||
| case 6: r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \ | |||
| case 5: r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \ | |||
| case 4: r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \ | |||
| case 3: r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \ | |||
| case 2: r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \ | |||
| case 1: r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \ | |||
| } | |||
/* Partial row count combined with the nmask lane mask; same intentional
 * fall-through as SWITCH_LOAD_A_8VEC. */
| #define SWITCH_MASK_LOAD_A_8VEC(aptr, cond) \ | |||
| switch((cond)) { \ | |||
| case 8: r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); \ | |||
| case 7: r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \ | |||
| case 6: r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \ | |||
| case 5: r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \ | |||
| case 4: r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \ | |||
| case 3: r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \ | |||
| case 2: r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \ | |||
| case 1: r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \ | |||
| } | |||
/* Transpose the 8 rows x 16 columns of 16-bit values held in r0..r7 (which
 * are overwritten as scratch). The three unpack rounds (16-, 32- and 64-bit)
 * work within each 128-bit lane, so on exit t<j> holds column j of the eight
 * rows in its low lane and column j+8 in its high lane. */
| #define REORDER_8x16(t0, t1, t2, t3, t4, t5, t6, t7) \ | |||
| t0 = _mm256_unpacklo_epi16(r0, r1); \ | |||
| t1 = _mm256_unpackhi_epi16(r0, r1); \ | |||
| t2 = _mm256_unpacklo_epi16(r2, r3); \ | |||
| t3 = _mm256_unpackhi_epi16(r2, r3); \ | |||
| t4 = _mm256_unpacklo_epi16(r4, r5); \ | |||
| t5 = _mm256_unpackhi_epi16(r4, r5); \ | |||
| t6 = _mm256_unpacklo_epi16(r6, r7); \ | |||
| t7 = _mm256_unpackhi_epi16(r6, r7); \ | |||
| r0 = _mm256_unpacklo_epi32(t0, t2); \ | |||
| r1 = _mm256_unpacklo_epi32(t1, t3); \ | |||
| r2 = _mm256_unpacklo_epi32(t4, t6); \ | |||
| r3 = _mm256_unpacklo_epi32(t5, t7); \ | |||
| r4 = _mm256_unpackhi_epi32(t0, t2); \ | |||
| r5 = _mm256_unpackhi_epi32(t1, t3); \ | |||
| r6 = _mm256_unpackhi_epi32(t4, t6); \ | |||
| r7 = _mm256_unpackhi_epi32(t5, t7); \ | |||
| t0 = _mm256_unpacklo_epi64(r0, r2); \ | |||
| t1 = _mm256_unpackhi_epi64(r0, r2); \ | |||
| t2 = _mm256_unpacklo_epi64(r4, r6); \ | |||
| t3 = _mm256_unpackhi_epi64(r4, r6); \ | |||
| t4 = _mm256_unpacklo_epi64(r1, r3); \ | |||
| t5 = _mm256_unpackhi_epi64(r1, r3); \ | |||
| t6 = _mm256_unpacklo_epi64(r5, r7); \ | |||
| t7 = _mm256_unpackhi_epi64(r5, r7); | |||
/* Store helpers. Each one combines the matching 128-bit lanes of t0<x> and
 * t1<x> (the two transposed 8-row halves) into one 256-bit vector `v` and
 * writes it to the output buffer. `v`, boffset, m_load and mmask come from
 * the expansion context. */

/* Low lanes (selector 0x20) -> output line x, lines are 32 elements apart. */
| #define STORE_256_LO(x) \ | |||
| v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \ | |||
| _mm256_storeu_si256((__m256i *)(boffset + x*32), v); | |||
/* High lanes (selector 0x31) -> output line x + 8. */
| #define STORE_256_HI(x) \ | |||
| v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \ | |||
| _mm256_storeu_si256((__m256i *)(boffset + (x + 8)*32), v); | |||
/* Masked variants: output lines are m_load elements apart and only the
 * 16-bit lanes selected by mmask are written. */
| #define MASK_STORE_256_LO(x) \ | |||
| v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \ | |||
| _mm256_mask_storeu_epi16(boffset + x*m_load, mmask, v); | |||
| #define MASK_STORE_256_HI(x) \ | |||
| v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \ | |||
| _mm256_mask_storeu_epi16(boffset + (x + 8)*m_load, mmask, v); | |||
/* STORE_256(x, y): x == 0 stores the low-lane line y, anything else stores
 * the high-lane line y + 8. Declares its own temporary `v`. */
| #define STORE_256(x, y) {\ | |||
| __m256i v; \ | |||
| if (x == 0) { STORE_256_LO(y); } \ | |||
| else { STORE_256_HI(y); } \ | |||
| } | |||
/* Masked counterpart of STORE_256. */
| #define MASK_STORE_256(x, y) {\ | |||
| __m256i v; \ | |||
| if (x == 0) { MASK_STORE_256_LO(y); } \ | |||
| else { MASK_STORE_256_HI(y); } \ | |||
| } | |||
/* Store the first `cond` output lines (1..15) with `func`, which is called as
 * func(half, index) - e.g. STORE_256 or MASK_STORE_256. The case labels fall
 * through on purpose: entering at `cond` emits line cond-1 and then every
 * lower line down to line 0. Values outside 1..15 store nothing. */
| #define SWITCH_STORE_16x(cond, func) \ | |||
| switch((cond)) {\ | |||
| case 15: func(1, 6); \ | |||
| case 14: func(1, 5); \ | |||
| case 13: func(1, 4); \ | |||
| case 12: func(1, 3); \ | |||
| case 11: func(1, 2); \ | |||
| case 10: func(1, 1); \ | |||
| case 9: func(1, 0); \ | |||
| case 8: func(0, 7); \ | |||
| case 7: func(0, 6); \ | |||
| case 6: func(0, 5); \ | |||
| case 5: func(0, 4); \ | |||
| case 4: func(0, 3); \ | |||
| case 3: func(0, 2); \ | |||
| case 2: func(0, 1); \ | |||
| case 1: func(0, 0); \ | |||
| } | |||
| int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { | |||
| IFLOAT *aoffset, *boffset; | |||
| IFLOAT *aoffset00, *aoffset01, *aoffset10, *aoffset11; | |||
| IFLOAT *boffset0; | |||
| __m256i r0, r1, r2, r3, r4, r5, r6, r7; | |||
| __m256i t00, t01, t02, t03, t04, t05, t06, t07; | |||
| __m256i t10, t11, t12, t13, t14, t15, t16, t17; | |||
| aoffset = a; | |||
| boffset = b; | |||
| BLASLONG n_count = n; | |||
| BLASLONG m_count = m; | |||
| for (; n_count > 15; n_count -= 16) { | |||
| aoffset00 = aoffset; | |||
| aoffset01 = aoffset00 + 8 * lda; | |||
| aoffset10 = aoffset01 + 8 * lda; | |||
| aoffset11 = aoffset10 + 8 * lda; | |||
| aoffset += 16; | |||
| m_count = m; | |||
| for (; m_count > 31; m_count -= 32) { | |||
| // first 16 rows | |||
| LOAD_A_8VEC(aoffset00); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| LOAD_A_8VEC(aoffset01); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3); | |||
| STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7); | |||
| STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3); | |||
| STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7); | |||
| // last 16 rows | |||
| boffset += 16; | |||
| LOAD_A_8VEC(aoffset10); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| LOAD_A_8VEC(aoffset11); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3); | |||
| STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7); | |||
| STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3); | |||
| STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7); | |||
| aoffset00 += 32 * lda; | |||
| aoffset01 += 32 * lda; | |||
| aoffset10 += 32 * lda; | |||
| aoffset11 += 32 * lda; | |||
| boffset += 31 * 16; | |||
| } | |||
| if (m_count > 1) { | |||
| int m_load = m_count & ~1; | |||
| m_count -= m_load; | |||
| __mmask16 mmask; | |||
| SWITCH_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| if (m_load > 8) { | |||
| SWITCH_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| } | |||
| int this_load = m_load > 16 ? 16 : m_load; | |||
| mmask = (1UL << this_load) - 1; | |||
| MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3); | |||
| MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7); | |||
| MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3); | |||
| MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7); | |||
| boffset0 = boffset; | |||
| if (m_load > 16) { | |||
| boffset += this_load; | |||
| SWITCH_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| if (m_load > 24) { | |||
| SWITCH_LOAD_A_8VEC(aoffset11, m_load - 24); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| } | |||
| this_load = m_load - 16; | |||
| mmask = (1UL << this_load) - 1; | |||
| MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3); | |||
| MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7); | |||
| MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3); | |||
| MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7); | |||
| } | |||
| boffset = boffset0 + 16 * m_load; | |||
| aoffset00 += m_load * lda; | |||
| } | |||
| if (m_count > 0) { | |||
| // just copy lask K to B directly | |||
| r0 = _mm256_loadu_si256((__m256i *)(aoffset00)); | |||
| _mm256_storeu_si256((__m256i *)(boffset), r0); | |||
| boffset += 16; | |||
| } | |||
| } | |||
| if (n_count > 0) { | |||
| __mmask16 nmask = (1UL << n_count) - 1; | |||
| aoffset00 = aoffset; | |||
| aoffset01 = aoffset00 + 8 * lda; | |||
| aoffset10 = aoffset01 + 8 * lda; | |||
| aoffset11 = aoffset10 + 8 * lda; | |||
| m_count = m; | |||
| for (; m_count > 31; m_count -= 32) { | |||
| // first 16 rows | |||
| MASK_LOAD_A_8VEC(aoffset00); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| MASK_LOAD_A_8VEC(aoffset01); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| SWITCH_STORE_16x(n_count, STORE_256); | |||
| // last 16 rows | |||
| boffset0 = boffset; | |||
| boffset += 16; | |||
| MASK_LOAD_A_8VEC(aoffset10); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| MASK_LOAD_A_8VEC(aoffset11); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| SWITCH_STORE_16x(n_count, STORE_256); | |||
| aoffset00 += 32 * lda; | |||
| aoffset01 += 32 * lda; | |||
| aoffset10 += 32 * lda; | |||
| aoffset11 += 32 * lda; | |||
| boffset = 32 * n_count + boffset0; | |||
| } | |||
| if (m_count > 1) { | |||
| int m_load = m_count & ~1; | |||
| m_count -= m_load; | |||
| __mmask16 mmask; | |||
| SWITCH_MASK_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| if (m_load > 8) { | |||
| SWITCH_MASK_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| } | |||
| int this_load = m_load > 16 ? 16 : m_load; | |||
| mmask = (1UL << this_load) - 1; | |||
| SWITCH_STORE_16x(n_count, MASK_STORE_256); | |||
| boffset0 = boffset; | |||
| if (m_load > 16) { | |||
| boffset += this_load; | |||
| SWITCH_MASK_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16); | |||
| REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); | |||
| if (m_load > 24) { | |||
| SWITCH_MASK_LOAD_A_8VEC(aoffset11, m_load - 24); | |||
| REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); | |||
| } | |||
| this_load = m_load - 16; | |||
| mmask = (1UL << this_load) - 1; | |||
| SWITCH_STORE_16x(n_count, MASK_STORE_256); | |||
| } | |||
| boffset = boffset0 + n_count * m_load; | |||
| aoffset00 += m_load * lda; | |||
| } | |||
| if (m_count > 0) { | |||
| // just copy lask K to B directly | |||
| r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aoffset00)); | |||
| _mm256_mask_storeu_epi16((__m256i *)(boffset), nmask, r0); | |||
| boffset += 16; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,42 @@ | |||
| /*************************************************************************** | |||
| Copyright (c) 2021, The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| *****************************************************************************/ | |||
| #include "common.h" | |||
| #include "sbgemm_block_microk_cooperlake.c" | |||
| // Define micro kernels for ALPHA not ONE scenarios | |||
| #undef ONE_ALPHA | |||
| #include "sbgemm_microk_cooperlake_template.c" | |||
| // Define micro kernels for ALPHA as ONE scenarios | |||
| #define ONE_ALPHA 1 | |||
| #include "sbgemm_microk_cooperlake_template.c" | |||
| int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) | |||
| { | |||
| return 0; | |||
| } | |||
| @@ -1771,6 +1771,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| #endif | |||
| #define USE_SGEMM_KERNEL_DIRECT 1 | |||
| #undef SBGEMM_DEFAULT_UNROLL_N | |||
| #undef SBGEMM_DEFAULT_UNROLL_M | |||
| #undef SBGEMM_DEFAULT_P | |||
| #undef SBGEMM_DEFAULT_R | |||
| #undef SBGEMM_DEFAULT_Q | |||
| // FIXME: the real unroll factors are UNROLL_M = UNROLL_N = 16. | |||
| // When UNROLL_M and UNROLL_N are equal, OpenBLAS reuses OCOPY as ICOPY. | |||
| // For AMX the two copy routines differ, so UNROLL_M is set to 32 as a workaround. | |||
| #define SBGEMM_DEFAULT_UNROLL_N 16 | |||
| #define SBGEMM_DEFAULT_UNROLL_M 32 | |||
| #define SBGEMM_DEFAULT_P 256 | |||
| #define SBGEMM_DEFAULT_Q 1024 | |||
| #define SBGEMM_DEFAULT_R sbgemm_r | |||
| #ifdef ARCH_X86 | |||
| #define SGEMM_DEFAULT_UNROLL_M 4 | |||