Browse Source

Move SGEMM_DIRECT after the CBLAS parameter check and add sgemm_direct_performant for ARM64

pull/5423/head
Martin Kroeker GitHub 9 months ago
parent
commit
de91afd2ae
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 32 additions and 24 deletions
  1. +32
    -24
      interface/gemm.c

+ 32
- 24
interface/gemm.c View File

@@ -424,30 +424,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

PRINT_DEBUG_CNAME;

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_avx512() )
#endif
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}
#endif
#if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_sme1())
#endif
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}
#endif
#endif

#ifndef COMPLEX
args.alpha = (void *)α
args.beta = (void *)β
@@ -564,6 +540,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
return;
}


if ((args.m == 0) || (args.n == 0)) return;
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_avx512() )
#endif
if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}
#endif
#if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (strcmp(gotoblas_corename(), "armv9sme") == 0 || strcmp(gotoblas_corename(), "vortexm4") == 0)
// if (support_sme1())
#endif
if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}
else
if (order == CblasRowMajor && beta != 0. && (!(alpha==1.&&beta==1.)) && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}
#endif
#endif

#endif

#if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
@@ -582,6 +587,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

if ((args.m == 0) || (args.n == 0)) return;




#if 0
fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);


Loading…
Cancel
Save