| @@ -31,17 +31,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| #define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | #define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | ||||
| #define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) | #define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) | ||||
| #define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) | |||||
| #define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) | #define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) | ||||
| #define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | #define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | ||||
| #if defined(B0) | #if defined(B0) | ||||
| #define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | #define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | ||||
| _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | ||||
| #define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||||
| _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) | |||||
| #else | #else | ||||
| #define STORE_512(M, N) \ | #define STORE_512(M, N) \ | ||||
| BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ | BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ | ||||
| result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | ||||
| asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \ | asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \ | ||||
| _mm512_storeu_ps(&C[offset##M##N], result##M##N) | _mm512_storeu_ps(&C[offset##M##N], result##M##N) | ||||
| #define MASK_STORE_512(M, N) \ | |||||
| BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ | |||||
| result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||||
| asm("vfmadd231ps (%1, %2, 4), %3, %0 %{%4%}": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512), "k"(mask)); \ | |||||
| _mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N) | |||||
| #endif | #endif | ||||
| #define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps() | #define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps() | ||||
| @@ -241,6 +249,51 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp | |||||
| STORE_512(0, 0); | STORE_512(0, 0); | ||||
| } | } | ||||
| } | } | ||||
| if (M - i > 0) { | |||||
| register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1; | |||||
| for (j = 0; j < n4; j += 4) { | |||||
| DECLARE_RESULT_512(0, 0); | |||||
| DECLARE_RESULT_512(0, 1); | |||||
| DECLARE_RESULT_512(0, 2); | |||||
| DECLARE_RESULT_512(0, 3); | |||||
| for (k = 0; k < K; k++) { | |||||
| MASK_LOAD_A_512(0, x); | |||||
| BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||||
| BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||||
| MATMUL_512(0, 0); | |||||
| MATMUL_512(0, 1); | |||||
| MATMUL_512(0, 2); | |||||
| MATMUL_512(0, 3); | |||||
| } | |||||
| MASK_STORE_512(0, 0); | |||||
| MASK_STORE_512(0, 1); | |||||
| MASK_STORE_512(0, 2); | |||||
| MASK_STORE_512(0, 3); | |||||
| } | |||||
| for (; j < n2; j += 2) { | |||||
| DECLARE_RESULT_512(0, 0); | |||||
| DECLARE_RESULT_512(0, 1); | |||||
| for (k = 0; k < K; k++) { | |||||
| MASK_LOAD_A_512(0, x); | |||||
| BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||||
| MATMUL_512(0, 0); | |||||
| MATMUL_512(0, 1); | |||||
| } | |||||
| MASK_STORE_512(0, 0); | |||||
| MASK_STORE_512(0, 1); | |||||
| } | |||||
| for (; j < N; j++) { | |||||
| DECLARE_RESULT_512(0, 0); | |||||
| for (k = 0; k < K; k++) { | |||||
| MASK_LOAD_A_512(0, x); | |||||
| BROADCAST_LOAD_B_512(x, 0); | |||||
| MATMUL_512(0, 0); | |||||
| } | |||||
| MASK_STORE_512(0, 0); | |||||
| } | |||||
| return; | |||||
| } | |||||
| __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); | __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); | ||||
| #if !defined(B0) | #if !defined(B0) | ||||
| __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); | __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); | ||||