diff --git a/interface/gemm.c b/interface/gemm.c index c5182c266..62bc44246 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -424,30 +424,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS PRINT_DEBUG_CNAME; -#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16) -#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) -#if defined(DYNAMIC_ARCH) - if (support_avx512() ) -#endif - if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) { - SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); - return; - } -#endif -#if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) -#if defined(DYNAMIC_ARCH) - if (support_sme1()) -#endif - if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { - SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); - return; - }else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { - SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - return; - } -#endif -#endif - #ifndef COMPLEX args.alpha = (void *)&alpha; args.beta = (void *)&beta; @@ -564,6 +540,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS return; } + + if ((args.m == 0) || (args.n == 0)) return; +#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16) +#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) +#if defined(DYNAMIC_ARCH) + if (support_avx512() ) +#endif + if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) { + SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); + return; + } +#endif +#if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) +#if defined(DYNAMIC_ARCH) +if
(strcmp(gotoblas_corename(), "armv9sme") == 0 || strcmp(gotoblas_corename(), "vortexm4") == 0) +// if (support_sme1()) +#endif + if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { + SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); + return; + } +else + if (order == CblasRowMajor && beta != 0. && (!(alpha==1.&&beta==1.)) && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { + SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + return; + } +#endif +#endif + #endif #if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) @@ -582,6 +587,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS if ((args.m == 0) || (args.n == 0)) return; + + + #if 0 fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n", args.m, args.n, args.k, args.lda, args.ldb, args.ldc);