From 7f89c6f353091aa09540b51158ab85a220b1192b Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 23 Aug 2025 14:20:15 -0700 Subject: [PATCH] SME-based direct sgemm currently requires leading dimensions to be the same as the matrix dimensions --- interface/gemm.c | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/interface/gemm.c b/interface/gemm.c index 62bc44246..92c75093c 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -266,6 +266,7 @@ void NAME(char *TRANSA, char *TRANSB, int transa, transb, nrowa, nrowb; blasint info; + int order = -1; char transA, transB; IFLOAT *buffer; @@ -557,15 +558,16 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS if (strcmp(gotoblas_corename(), "armv9sme") == 0 || strcmp(gotoblas_corename(), "vortexm4") == 0) // if (support_sme1()) #endif - if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { + if (order == CblasRowMajor && m==lda && n ==ldb && k==ldc && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); return; } else - if (order == CblasRowMajor && beta != 0. 
&& (!(alpha==1.&&beta==1.)) && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { + if (order == CblasRowMajor && m==lda && n==ldb && k==ldc && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) { SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); return; } + #endif #endif @@ -587,9 +589,6 @@ else if ((args.m == 0) || (args.n == 0)) return; - - - #if 0 fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n", args.m, args.n, args.k, args.lda, args.ldb, args.ldc); @@ -626,6 +625,7 @@ else } bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1)); if (is_efficient_gemv) { +fprintf(stderr,"gemv_forwarding\n"); GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); return; } @@ -649,6 +649,7 @@ else } bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1)); if (is_efficient_gemv) { +fprintf(stderr,"gemv_forwarding\n"); GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); return; }