From 63ce52ee7795d3acd06c490a21e24ab2c1152501 Mon Sep 17 00:00:00 2001 From: Ye Tao Date: Thu, 29 May 2025 10:51:29 +0000 Subject: [PATCH] change data type of bgemm alpha and beta from bfloat16 to fp32 and add makefiles changes for bgemm interface --- CONTRIBUTORS.md | 3 +- Makefile.system | 6 +- Makefile.tail | 4 +- cblas.h | 2 +- common_interface.h | 4 +- common_level3.h | 4 +- common_param.h | 4 +- driver/level3/Makefile | 21 +- driver/level3/level3.c | 18 +- driver/level3/level3_thread.c | 10 + interface/Makefile | 22 +- interface/gemm.c | 19 +- kernel/Makefile.L3 | 53 ++++ kernel/arm64/sbgemm_kernel_4x4_neoversev1.c | 2 +- .../arm64/sbgemm_kernel_4x4_neoversev1_impl.c | 4 +- kernel/generic/bgemmkernel_2x2.c | 227 ------------------ kernel/generic/gemm_beta.c | 29 ++- kernel/generic/gemmkernel_2x2.c | 75 +++++- test/Makefile | 44 +++- test/compare_sgemm_bgemm.c | 68 +++++- 20 files changed, 337 insertions(+), 282 deletions(-) delete mode 100644 kernel/generic/bgemmkernel_2x2.c diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index d8f57ef60..b9f0d16c3 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -251,6 +251,7 @@ In chronological order: * Ye Tao * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 * [2025-02-27] Add sbgemv_n_neon kernel + * [2025-05-17] Impl prototype of BGEMM interface * Abhishek Kumar - * [2025-04-22] Optimise dot kernel for NEOVERSE V1 \ No newline at end of file + * [2025-04-22] Optimise dot kernel for NEOVERSE V1 diff --git a/Makefile.system b/Makefile.system index 38646c3c6..ff6b87555 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1544,6 +1544,9 @@ ifeq ($(USE_TLS), 1) CCOMMON_OPT += -DUSE_TLS endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +CCOMMON_OPT += -DBUILD_BFLOAT16_ONLY +endif ifeq ($(BUILD_BFLOAT16), 1) CCOMMON_OPT += -DBUILD_BFLOAT16 endif @@ -1888,6 +1891,7 @@ export FUNCTION_PROFILE export TARGET_CORE export NO_AVX512 export NO_AVX2 +export BUILD_BFLOAT16_ONLY export BUILD_BFLOAT16 export NO_LSX export NO_LASX @@ -1912,7 
+1916,7 @@ export ZGEMM3M_UNROLL_M export ZGEMM3M_UNROLL_N export XGEMM3M_UNROLL_M export XGEMM3M_UNROLL_N - +# Todo: add bgemm unroll factors ifdef USE_CUDA export CUDADIR diff --git a/Makefile.tail b/Makefile.tail index 54ba649db..5b5de184a 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -11,7 +11,7 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) +BLASOBJS = $(SBEXTOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) ifdef EXPRECISION @@ -24,6 +24,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(BBLASOBJS) : override CFLAGS += -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX -UBFLOAT16 -USMALL_MATRIX_OPT $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX @@ -42,6 +43,7 @@ $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) + libs :: $(BLASOBJS) $(COMMONOBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ diff --git a/cblas.h b/cblas.h index c911331a8..25de498b0 100644 --- a/cblas.h +++ b/cblas.h @@ -475,7 +475,7 @@ void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST en OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * 
group_size); void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, - OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); + OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); #ifdef __cplusplus } diff --git a/common_interface.h b/common_interface.h index 4f2906014..ae4786acb 100644 --- a/common_interface.h +++ b/common_interface.h @@ -481,8 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint xdouble *, blasint *, xdouble *, xdouble *, blasint *); /* Level 3 routines */ -void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *, - bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); +void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + bfloat16 *, blasint *, bfloat16 *, blasint *, float *, bfloat16 *, blasint *); void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, diff --git a/common_level3.h b/common_level3.h index 1cd088821..eaa33b2a2 100644 --- a/common_level3.h +++ b/common_level3.h @@ -54,7 +54,7 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); -int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16, +int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, 
BLASLONG); int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); @@ -513,7 +513,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); // add bgemm kernel -int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); diff --git a/common_param.h b/common_param.h index 5ae487929..b2b4b9cf6 100644 --- a/common_param.h +++ b/common_param.h @@ -54,8 +54,8 @@ typedef struct { int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; int bgemm_align_k; - int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); - int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); + int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); + int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); diff --git a/driver/level3/Makefile b/driver/level3/Makefile index bd8351013..02c5ab92e 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -52,6 +52,13 @@ ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_nn.$(SUFFIX) 
sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX) +endif + +BLASOBJS += \ + gemm_nn.$(SUFFIX) gemm_nt.$(SUFFIX) gemm_tn.$(SUFFIX) gemm_tt.$(SUFFIX) + SBLASOBJS += \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ @@ -376,6 +383,18 @@ endif all :: +bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -432,8 +451,8 @@ cgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F) - cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F) cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h diff --git a/driver/level3/level3.c b/driver/level3/level3.c index c8f6d966a..4596a8c12 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -43,9 +43,11 @@ #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) #ifndef COMPLEX #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ - GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ - BETA[0], NULL, 0, NULL, 0, \ - (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC) + do { \ + GEMM_BETA((M_TO) - (M_FROM), (N_TO - 
N_FROM), 0, \ + BETA[0], NULL, 0, NULL, 0, \ + (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC); \ + } while (0) #else #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ @@ -189,7 +191,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; +#if defined(BFLOAT16_ONLY) + float *alpha, *beta; +#else FLOAT *alpha, *beta; +#endif IFLOAT *a, *b; FLOAT *c; BLASLONG m_from, m_to, n_from, n_to; @@ -224,8 +230,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, ldb = LDB; ldc = LDC; +#if defined(BFLOAT16_ONLY) + alpha = (float *)args -> alpha; + beta = (float *)args -> beta; +#else alpha = (FLOAT *)args -> alpha; beta = (FLOAT *)args -> beta; +#endif + m_from = 0; m_to = M; diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index e0bb2e122..64be542c2 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -239,7 +239,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG k, lda, ldb, ldc; BLASLONG m_from, m_to, n_from, n_to; +#if defined(BFLOAT16_ONLY) + float *alpha, *beta; +#else FLOAT *alpha, *beta; +#endif IFLOAT *a, *b; FLOAT *c; job_t *job = (job_t *)args -> common; @@ -277,8 +281,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, ldb = LDB; ldc = LDC; +#if defined(BFLOAT16_ONLY) + alpha = (float *)args -> alpha; + beta = (float *)args -> beta; +#else alpha = (FLOAT *)args -> alpha; beta = (FLOAT *)args -> beta; +#endif + /* Initialize 2D CPU distribution */ nthreads_m = args -> nthreads; diff --git a/interface/Makefile b/interface/Makefile index f09a6f46b..adf4a9a6f 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) 
sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLAS3OBJ = bgemm.$(SUFFIX) +endif + DBLAS1OBJS = \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \ dcopy.$(SUFFIX) dscal.$(SUFFIX) \ @@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) +endif + CDBLAS1OBJS = \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ @@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) +BBLAS3OBJ += $(CBBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -403,6 +412,7 @@ SBEXTOBJS += $(CSBEXTOBJS) CBAUXOBJS += $(CXERBLAOBJ) endif +BBLASOBJS = $(BBLAS3OBJ) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) @@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) +level3 : $(BBLAS3OBJ) $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ aux : $(CBAUXOBJS) @@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +bgemm.$(SUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) +endif + 
sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) @@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +cblas_bgemm.$(SUFFIX) cblas_bgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +endif + cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) diff --git a/interface/gemm.c b/interface/gemm.c index bc1631d9a..d08304b57 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -250,6 +250,15 @@ static inline int get_gemm_optimal_nthreads(double MNK) { #ifndef CBLAS +#ifdef BFLOAT16_ONLY +void NAME(char *TRANSA, char *TRANSB, + blasint *M, blasint *N, blasint *K, + float *alpha, + IFLOAT *a, blasint *ldA, + IFLOAT *b, blasint *ldB, + float *beta, + FLOAT *c, blasint *ldC){ +#else void NAME(char *TRANSA, char *TRANSB, blasint *M, blasint *N, blasint *K, FLOAT *alpha, @@ -257,7 +266,7 @@ void NAME(char *TRANSA, char *TRANSB, IFLOAT *b, blasint *ldB, FLOAT *beta, FLOAT *c, blasint *ldC){ - +#endif blas_arg_t args; int transa, transb, nrowa, nrowb; @@ -366,11 +375,19 @@ void NAME(char *TRANSA, char *TRANSB, void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m, blasint n, blasint k, #ifndef COMPLEX + #ifdef BFLOAT16_ONLY + float alpha, + IFLOAT *a, blasint lda, + IFLOAT *b, blasint ldb, + float beta, + FLOAT *c, blasint ldc) { + #else FLOAT alpha, IFLOAT *a, blasint lda, IFLOAT *b, blasint ldb, FLOAT beta, FLOAT *c, blasint ldc) { + #endif #else void *valpha, void *va, blasint lda, diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 94b66a17c..fea5b0ea9 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -136,6 +136,25 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +ifndef BGEMMKERNEL +BGEMM_BETA = ../generic/gemm_beta.c +BGEMMKERNEL = ../generic/gemmkernel_2x2.c +BGEMMINCOPY = 
../generic/gemm_ncopy_2.c +BGEMMITCOPY = ../generic/gemm_tcopy_2.c +BGEMMONCOPY = ../generic/gemm_ncopy_2.c +BGEMMOTCOPY = ../generic/gemm_tcopy_2.c +BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) +BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) +BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) +BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) +endif +BKERNELOBJS += \ + bgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(BGEMMINCOPYOBJ) $(BGEMMITCOPYOBJ) \ + $(BGEMMONCOPYOBJ) $(BGEMMOTCOPYOBJ) +endif + ifeq ($(BUILD_BFLOAT16), 1) ifndef SBGEMMKERNEL SBGEMM_BETA = ../generic/gemm_beta.c @@ -216,6 +235,11 @@ XKERNELOBJS += \ $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) + +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLASOBJS += $(BKERNELOBJS) +endif + ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += $(SBKERNELOBJS) endif @@ -226,6 +250,10 @@ CBLASOBJS += $(CKERNELOBJS) ZBLASOBJS += $(ZKERNELOBJS) XBLASOBJS += $(XKERNELOBJS) +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLASOBJS += bgemm_beta$(TSUFFIX).$(SUFFIX) +endif + ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) endif @@ -651,6 +679,11 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +ifeq ($(BUILD_BFLOAT16_ONLY),1) +$(KDIR)bgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ +endif + ifeq ($(BUILD_BFLOAT16),1) $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ @@ -678,6 +711,21 @@ ifeq ($(ARCH), E2K) USE_TRMM = 1 endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) + +$(KDIR)$(BGEMMONCOPYOBJ) : $(KERNELDIR)/$(BGEMMONCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMOTCOPYOBJ) : $(KERNELDIR)/$(BGEMMOTCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + 
+$(KDIR)$(BGEMMINCOPYOBJ) : $(KERNELDIR)/$(BGEMMINCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMITCOPYOBJ) : $(KERNELDIR)/$(BGEMMITCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +endif ifeq ($(BUILD_BFLOAT16), 1) @@ -874,6 +922,11 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ +endif + ifeq ($(BUILD_BFLOAT16), 1) $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) diff --git a/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c b/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c index 889b5fc5b..772e45da9 100644 --- a/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c +++ b/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c @@ -35,7 +35,7 @@ #undef ALPHA_ONE #include "sbgemm_kernel_4x4_neoversev1_impl.c" -int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { if (alpha == 1.0f) return sbgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); diff --git a/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c b/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c index b6d9e9816..02b101f11 100644 --- a/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c +++ b/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c @@ -78,11 +78,11 @@ #ifdef ALPHA_ONE int sbgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, - FLOAT alpha, IFLOAT *A, IFLOAT *B, + float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) #else int sbgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, - FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, + float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) #endif { diff --git a/kernel/generic/bgemmkernel_2x2.c b/kernel/generic/bgemmkernel_2x2.c deleted file mode 100644 index 
5fe8d3255..000000000 --- a/kernel/generic/bgemmkernel_2x2.c +++ /dev/null @@ -1,227 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2025, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- * *****************************************************************************/ - -#include "common.h" - -static float bfloat16tof32(bfloat16 f16) { - float result = 0; - unsigned short *q = (unsigned short *)(&result); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - q[0] = f16; -#else - q[1] = f16; -#endif - return result; -} - -static bfloat16 f32tobfloat16(float f32) { - unsigned short *q = (unsigned short *)(&f32); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - return q[0]; -#else - return q[1]; -#endif -} - -#define BF16TOF32(x) (bfloat16tof32(x)) -#define F32TOBF16(x) (f32tobfloat16(x)) - -int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba, - IFLOAT *bb, FLOAT *C, BLASLONG ldc) { - BLASLONG i, j, k; - FLOAT *C0, *C1; // bfloat16 - IFLOAT *ptrba, *ptrbb; - float res0, res1, res2, res3; - float load0, load1, load2, load3, load4, load5, load6, load7; - - float alpha_ = BF16TOF32(alpha); - - for (j = 0; j < bn / 2; j += 1) { - C0 = C; - C1 = C0 + ldc; - ptrba = ba; - for (i = 0; i < bm / 2; i += 1) { - ptrbb = bb; - res0 = 0; - res1 = 0; - res2 = 0; - res3 = 0; - for (k = 0; k < bk / 4; k += 1) { - load0 = BF16TOF32(ptrba[2 * 0 + 0]); - load2 = BF16TOF32(ptrba[2 * 0 + 1]); - load4 = BF16TOF32(ptrba[2 * 1 + 0]); - load6 = BF16TOF32(ptrba[2 * 1 + 1]); - - load1 = BF16TOF32(ptrbb[2 * 0 + 0]); - load3 = BF16TOF32(ptrbb[2 * 0 + 1]); - load5 = BF16TOF32(ptrbb[2 * 1 + 0]); - load7 = BF16TOF32(ptrbb[2 * 1 + 1]); - - res0 = res0 + load0 * load1; - res1 = res1 + load2 * load1; - res2 = res2 + load0 * load3; - res3 = res3 + load2 * load3; - - res0 = res0 + load4 * load5; - res1 = res1 + load6 * load5; - res2 = res2 + load4 * load7; - res3 = res3 + load6 * load7; - - load0 = BF16TOF32(ptrba[2 * 2 + 0]); - load2 = BF16TOF32(ptrba[2 * 2 + 1]); - load4 = BF16TOF32(ptrba[2 * 3 + 0]); - load6 = BF16TOF32(ptrba[2 * 3 + 1]); - - load1 = BF16TOF32(ptrbb[2 * 2 + 0]); - load3 = BF16TOF32(ptrbb[2 * 2 + 1]); - load5 = BF16TOF32(ptrbb[2 * 3 + 0]); - load7 = 
BF16TOF32(ptrbb[2 * 3 + 1]); - - res0 = res0 + load0 * load1; - res1 = res1 + load2 * load1; - res2 = res2 + load0 * load3; - res3 = res3 + load2 * load3; - - res0 = res0 + load4 * load5; - res1 = res1 + load6 * load5; - res2 = res2 + load4 * load7; - res3 = res3 + load6 * load7; - } - - for (k = 0; k < (bk & 3); k += 1) { - load0 = BF16TOF32(ptrba[2 * 0 + 0]); - load2 = BF16TOF32(ptrba[2 * 0 + 1]); - load1 = BF16TOF32(ptrbb[2 * 0 + 0]); - load3 = BF16TOF32(ptrbb[2 * 0 + 1]); - - res0 = res0 + load0 * load1; - res1 = res1 + load2 * load1; - res2 = res2 + load0 * load3; - res3 = res3 + load2 * load3; - - ptrba = ptrba + 2; - ptrbb = ptrbb + 2; - } - - res0 = res0 * alpha_ + BF16TOF32(C0[0]); - res1 = res1 * alpha_ + BF16TOF32(C0[1]); - res2 = res2 * alpha_ + BF16TOF32(C1[0]); - res3 = res3 * alpha_ + BF16TOF32(C1[1]); - - C0[0] = F32TOBF16(res0); - C0[1] = F32TOBF16(res1); - C1[0] = F32TOBF16(res2); - C1[1] = F32TOBF16(res3); - - C0 = C0 + 2; - C1 = C1 + 2; - } - - for (i = 0; i < (bm & 1); i += 1) { - ptrbb = bb; - res0 = 0; - res1 = 0; - for (k = 0; k < bk; k += 1) { - load0 = BF16TOF32(ptrba[0 + 0]); - load1 = BF16TOF32(ptrbb[2 * 0 + 0]); - load2 = BF16TOF32(ptrbb[2 * 0 + 1]); - - res0 = res0 + load0 * load1; - res1 = res1 + load0 * load2; - - ptrba = ptrba + 1; - ptrbb = ptrbb + 2; - } - - res0 = res0 * alpha_ + BF16TOF32(C0[0]); - res1 = res1 * alpha_ + BF16TOF32(C1[0]); - - C0[0] = res0; - C1[0] = res1; - - C0 = C0 + 1; - C1 = C1 + 1; - } - - k = (bk << 1); - bb = bb + k; - i = (ldc << 1); - C = C + i; - } - - for (j = 0; j < (bn & 1); j += 1) { - C0 = C; - ptrba = ba; - for (i = 0; i < bm / 2; i += 1) { - ptrbb = bb; - res0 = 0; - res1 = 0; - for (k = 0; k < bk; k += 1) { - load0 = BF16TOF32(ptrba[2 * 0 + 0]); - load2 = BF16TOF32(ptrba[2 * 0 + 1]); - - load1 = BF16TOF32(ptrbb[0 + 0]); - - res0 = res0 + load0 * load1; - res1 = res1 + load2 * load1; - - ptrba = ptrba + 2; - ptrbb = ptrbb + 1; - } - - res0 = res0 * alpha_ + BF16TOF32(C0[0]); - res1 = res1 * 
alpha_ + BF16TOF32(C0[1]); - - C0[0] = F32TOBF16(res0); - C0[1] = F32TOBF16(res1); - C0 = C0 + 2; - } - - for (i = 0; i < (bm & 1); i += 1) { - ptrbb = bb; - res0 = 0; - for (k = 0; k < bk; k += 1) { - load0 = BF16TOF32(ptrba[0 + 0]); - load1 = BF16TOF32(ptrbb[0 + 0]); - res0 += load0 * load1; - ptrba = ptrba + 1; - ptrbb = ptrbb + 1; - } - - res0 = res0 * alpha_ + BF16TOF32(C0[0]); - C0[0] = F32TOBF16(res0); - C0 = C0 + 1; - } - - k = (bk << 0); - bb = bb + k; - C = C + ldc; - } - - return 0; -} diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index f399de090..36522ad22 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -71,11 +71,15 @@ f32tobfloat16(float f32) #define F32TOBF16(x) x #endif - -int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, +#if defined(BFLOAT16_ONLY) +int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, float beta, IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ - +#else +int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, + IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, + FLOAT *c, BLASLONG ldc){ +#endif BLASLONG i, j; BLASLONG chunk, remain; @@ -83,25 +87,24 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, c_offset = c; chunk = m >> 3; remain = m & 7; - float beta = BF16TOF32(beta_in); if (beta == ZERO){ for(j=n; j>0; j--){ c_offset1 = c_offset; c_offset += ldc; for(i=chunk; i>0; i--){ - *(c_offset1 + 0) = ZERO; - *(c_offset1 + 1) = ZERO; - *(c_offset1 + 2) = ZERO; - *(c_offset1 + 3) = ZERO; - *(c_offset1 + 4) = ZERO; - *(c_offset1 + 5) = ZERO; - *(c_offset1 + 6) = ZERO; - *(c_offset1 + 7) = ZERO; + *(c_offset1 + 0) = F32TOBF16(ZERO); + *(c_offset1 + 1) = F32TOBF16(ZERO); + *(c_offset1 + 2) = F32TOBF16(ZERO); + *(c_offset1 + 3) = F32TOBF16(ZERO); + *(c_offset1 + 4) = F32TOBF16(ZERO); + *(c_offset1 + 5) = F32TOBF16(ZERO); + *(c_offset1 + 6) = F32TOBF16(ZERO); + *(c_offset1 + 7) = 
F32TOBF16(ZERO); c_offset1 += 8; } for(i=remain; i>0; i--){ - *c_offset1 = ZERO; + *c_offset1 = F32TOBF16(ZERO); c_offset1 ++; } } diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index bf1c3ae38..3cf2d928f 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,5 +1,32 @@ +/*************************************************************************** +Copyright (c) 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + #include "common.h" -#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) +#if (defined(BFLOAT16) || defined(BFLOAT16_ONLY))&& defined(BFLOAT16CONVERSION) static float bfloat16tof32 (bfloat16 f16) { @@ -12,12 +39,29 @@ bfloat16tof32 (bfloat16 f16) #endif return result; } + +static bfloat16 f32tobfloat16(float f32) { + unsigned short *q = (unsigned short *)(&f32); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return q[0]; +#else + return q[1]; +#endif +} + #define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) (f32tobfloat16(x)) #else #define BF16TOF32(x) x +#define F32TOBF16(x) x #endif + +#ifdef BFLOAT16_ONLY +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk, float alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc +#else int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc -#ifdef TRMMKERNEL +#endif + #ifdef TRMMKERNEL ,BLASLONG offset #endif ) @@ -90,13 +134,17 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); res1 = res1*alpha; - C0[1] = C0[1]+res1; + C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); res2 = res2*alpha; - C1[0] = C1[0]+res2; + C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); res3 = res3*alpha; - C1[1] = C1[1]+res3; + C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[1])); C0 = C0+2; C1 = C1+2; } @@ -116,9 +164,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , 
BF16TOF32(C0[0])); res1 = res1*alpha; - C1[0] = C1[0]+res1; + C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); C0 = C0+1; C1 = C1+1; } @@ -147,9 +197,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); res1 = res1*alpha; - C0[1] = C0[1]+res1; + C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); C0 = C0+2; } for (i=0; i<(bm&1); i+=1) @@ -165,7 +217,8 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); C0 = C0+1; } k = (bk<<0); diff --git a/test/Makefile b/test/Makefile index 9ba88988b..e9a77dc05 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### TOPDIR = .. include ../Makefile.system ifeq ($(F_COMPILER),GFORTRAN) @@ -164,6 +192,9 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +BF3= test_bgemm +endif ifeq ($(BUILD_BFLOAT16),1) B3= test_sbgemm endif @@ -192,11 +223,15 @@ endif ifeq ($(SUPPORT_GEMM3M),1) -level3: $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m +level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m else -level3: $(B3) $(S3) $(D3) $(C3) $(Z3) +level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) endif ifneq ($(CROSS), 1) rm -f ?BLAT3.SUMM +ifeq ($(BUILD_BFLOAT16_ONLY),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM + @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 +endif ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 @@ -366,6 +401,11 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) $(FC) $(FLDFLAGS) -o zblat3 zblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) +endif + ifeq 
($(BUILD_BFLOAT16),1) test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) @@ -387,7 +427,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_sbgemm sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c index e5a2ba46d..a399efa5f 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -57,10 +57,8 @@ int main(int argc, char *argv[]) { char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; - bfloat16 alpha_bf16 = convert_to_bf16(alpha), - beta_bf16 = convert_to_bf16(beta); - for (x = 1; x <= BGEMM_LARGEST; x++) { + for (x = 1; x <= loop; x++) { if ((x > 100) && (x != BGEMM_LARGEST)) continue; m = k = n = x; @@ -79,14 +77,14 @@ int main(int argc, char *argv[]) { for (int i = 0; i < m; i++) { for (int j = 0; j < k; j++) { - AA[i * k + j] = (i * k + j + 1) % 100; + AA[i * k + j] = (i * k + j + 1) % 5; A[i * k + j] = AA[i * k + j]; } } for (int i = 0; i < n; i++) { for (int j = 0; j < k; j++) { - BB[i * k + j] = (i * k + j + 1) % 100; + BB[i * k + j] = (i * k + j + 1) % 5; B[i * k + j] = BB[i * k + j]; } } @@ -102,23 +100,73 @@ int main(int argc, char *argv[]) { } else { transB = 'T'; } + // printf("******** x = %d, y = %d********\n", x, y); + // printf("Matrix AA (m x k):\n"); + // for (int i = 0; i < m; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", (float)AA[i * k + j]); // or %4.1f if float + // } + // printf("\n"); + // } + + // printf("Matrix A (copy of AA):\n"); + // for (int i = 0; i < m; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", A[i * k + j]); + // } + // printf("\n"); + // } + + // printf("Matrix 
BB (n x k):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", (float)BB[i * k + j]); + // } + // printf("\n"); + // } + + // printf("Matrix B (copy of BB):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", B[i * k + j]); + // } + // printf("\n"); + // } memset(C, 0, m * n * sizeof(FLOAT)); memset(CC, 0, m * n * sizeof(bfloat16)); - SGEMM(&transA, &transB, &m, &n, &k, &alpha_bf16, A, &m, B, &k, &beta_bf16, + SGEMM(&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, C, &m); BGEMM(&transA, &transB, &m, &n, &k, &alpha, (bfloat16 *)AA, &m, (bfloat16 *)BB, &k, &beta, (bfloat16 *)CC, &m); + + + // printf("Matrix CC (n x m):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < m; j++) { + // printf("%.2f ", (float)CC[i * m + j]); + // } + // printf("\n"); + // } + + // printf("Matrix C :\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", C[i * k + j]); + // } + // printf("\n"); + // } for (i = 0; i < n; i++) { for (j = 0; j < m; j++) { - for (l = 0; l < k; l++) { - if (fabs(CC[i * m + j] - C[i * m + j]) > 1.0) { - ret++; + if (fabs((float)CC[i * m + j] - C[i * m + j]) > 1.0) { + ret ++; } - } } } + + printf("x = %d, err = %d\n", x, ret); + ret = 0; } free(A);