| @@ -251,6 +251,7 @@ In chronological order: | |||
| * Ye Tao <ye.tao@arm.com> | |||
| * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 | |||
| * [2025-02-27] Add sbgemv_n_neon kernel | |||
| * [2025-05-17] Impl prototype of BGEMM interface | |||
| * Abhishek Kumar <https://github.com/abhishek-iitmadras> | |||
| * [2025-04-22] Optimise dot kernel for NEOVERSE V1 | |||
| * [2025-04-22] Optimise dot kernel for NEOVERSE V1 | |||
| @@ -1544,6 +1544,9 @@ ifeq ($(USE_TLS), 1) | |||
| CCOMMON_OPT += -DUSE_TLS | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| CCOMMON_OPT += -DBUILD_BFLOAT16_ONLY | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16), 1) | |||
| CCOMMON_OPT += -DBUILD_BFLOAT16 | |||
| endif | |||
| @@ -1888,6 +1891,7 @@ export FUNCTION_PROFILE | |||
| export TARGET_CORE | |||
| export NO_AVX512 | |||
| export NO_AVX2 | |||
| export BUILD_BFLOAT16_ONLY | |||
| export BUILD_BFLOAT16 | |||
| export NO_LSX | |||
| export NO_LASX | |||
| @@ -1912,7 +1916,7 @@ export ZGEMM3M_UNROLL_M | |||
| export ZGEMM3M_UNROLL_N | |||
| export XGEMM3M_UNROLL_M | |||
| export XGEMM3M_UNROLL_N | |||
| # Todo: add bgemm unroll factors | |||
| ifdef USE_CUDA | |||
| export CUDADIR | |||
| @@ -11,7 +11,7 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
| HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
| BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||
| BLASOBJS = $(SBEXTOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||
| BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) | |||
| ifdef EXPRECISION | |||
| @@ -24,6 +24,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||
| BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) | |||
| endif | |||
| $(BBLASOBJS) : override CFLAGS += -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX -UBFLOAT16 -USMALL_MATRIX_OPT | |||
| $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX | |||
| $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX | |||
| $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX | |||
| @@ -42,6 +43,7 @@ $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
| $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
| $(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
| libs :: $(BLASOBJS) $(COMMONOBJS) | |||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | |||
| @@ -475,7 +475,7 @@ void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST en | |||
| OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, | |||
| OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); | |||
| OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -481,8 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint | |||
| xdouble *, blasint *, xdouble *, xdouble *, blasint *); | |||
| /* Level 3 routines */ | |||
| void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *, | |||
| bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); | |||
| void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
| bfloat16 *, blasint *, bfloat16 *, blasint *, float *, bfloat16 *, blasint *); | |||
| void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
| bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); | |||
| void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
| @@ -54,7 +54,7 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||
| int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | |||
| int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||
| int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
| bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||
| int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
| bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||
| @@ -513,7 +513,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl | |||
| int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); | |||
| // add bgemm kernel | |||
| int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||
| int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||
| int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); | |||
| int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); | |||
| int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); | |||
| @@ -54,8 +54,8 @@ typedef struct { | |||
| int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; | |||
| int bgemm_align_k; | |||
| int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||
| int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||
| int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||
| int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||
| int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | |||
| int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | |||
| @@ -52,6 +52,13 @@ ifeq ($(BUILD_BFLOAT16),1) | |||
| SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX) | |||
| endif | |||
| BLASOBJS += \ | |||
| gemm_nn.$(SUFFIX) gemm_nt.$(SUFFIX) gemm_tn.$(SUFFIX) gemm_tt.$(SUFFIX) | |||
| SBLASOBJS += \ | |||
| sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ | |||
| strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ | |||
| @@ -376,6 +383,18 @@ endif | |||
| all :: | |||
| bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | |||
| bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) | |||
| bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) | |||
| bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) | |||
| sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | |||
| @@ -432,8 +451,8 @@ cgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F) | |||
| cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F) | |||
| cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||
| @@ -43,9 +43,11 @@ | |||
| #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) | |||
| #ifndef COMPLEX | |||
| #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | |||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | |||
| BETA[0], NULL, 0, NULL, 0, \ | |||
| (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC) | |||
| do { \ | |||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | |||
| BETA[0], NULL, 0, NULL, 0, \ | |||
| (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC); \ | |||
| } while (0) | |||
| #else | |||
| #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | |||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | |||
| @@ -189,7 +191,11 @@ | |||
| int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||
| XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ | |||
| BLASLONG k, lda, ldb, ldc; | |||
| /* bgemm objects are compiled with -DBFLOAT16_ONLY and receive float alpha/beta. | |||
| * Key on that per-object macro, not the build-wide BUILD_BFLOAT16_ONLY switch, | |||
| * so every other precision keeps FLOAT-typed scalars. */ | |||
| #if defined(BFLOAT16_ONLY) | |||
| float *alpha, *beta; | |||
| #else | |||
| FLOAT *alpha, *beta; | |||
| #endif | |||
| IFLOAT *a, *b; | |||
| FLOAT *c; | |||
| BLASLONG m_from, m_to, n_from, n_to; | |||
| @@ -224,8 +230,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||
| ldb = LDB; | |||
| ldc = LDC; | |||
| /* Must use the same condition as the alpha/beta declaration above. */ | |||
| #if defined(BFLOAT16_ONLY) | |||
| alpha = (float *)args -> alpha; | |||
| beta = (float *)args -> beta; | |||
| #else | |||
| alpha = (FLOAT *)args -> alpha; | |||
| beta = (FLOAT *)args -> beta; | |||
| #endif | |||
| m_from = 0; | |||
| m_to = M; | |||
| m_from = 0; | |||
| m_to = M; | |||
| @@ -239,7 +239,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||
| BLASLONG k, lda, ldb, ldc; | |||
| BLASLONG m_from, m_to, n_from, n_to; | |||
| /* NOTE(review): this tests the build-wide BUILD_BFLOAT16_ONLY switch, not the | |||
| * per-object BFLOAT16_ONLY macro; if that switch is defined for every object, | |||
| * instantiations where FLOAT is not float would also get float scalars here -- confirm. */ | |||
| #if defined(BUILD_BFLOAT16_ONLY) | |||
| float *alpha, *beta; | |||
| #else | |||
| FLOAT *alpha, *beta; | |||
| #endif | |||
| IFLOAT *a, *b; | |||
| FLOAT *c; | |||
| job_t *job = (job_t *)args -> common; | |||
| @@ -277,8 +281,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||
| ldb = LDB; | |||
| ldc = LDC; | |||
| #if defined(BUILD_BFLOAT16_ONLY) | |||
| alpha = (float *)args -> alpha; | |||
| beta = (float *)args -> beta; | |||
| #else | |||
| alpha = (FLOAT *)args -> alpha; | |||
| beta = (FLOAT *)args -> beta; | |||
| #endif | |||
| /* Initialize 2D CPU distribution */ | |||
| nthreads_m = args -> nthreads; | |||
| @@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) | |||
| SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| BBLAS3OBJ = bgemm.$(SUFFIX) | |||
| endif | |||
| DBLAS1OBJS = \ | |||
| daxpy.$(SUFFIX) dswap.$(SUFFIX) \ | |||
| dcopy.$(SUFFIX) dscal.$(SUFFIX) \ | |||
| @@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S | |||
| CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) | |||
| endif | |||
| CDBLAS1OBJS = \ | |||
| cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ | |||
| cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ | |||
| @@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS) | |||
| SBBLAS1OBJS += $(CSBBLAS1OBJS) | |||
| SBBLAS2OBJS += $(CSBBLAS2OBJS) | |||
| SBBLAS3OBJS += $(CSBBLAS3OBJS) | |||
| BBLAS3OBJ += $(CBBLAS3OBJS) | |||
| DBLAS1OBJS += $(CDBLAS1OBJS) | |||
| DBLAS2OBJS += $(CDBLAS2OBJS) | |||
| DBLAS3OBJS += $(CDBLAS3OBJS) | |||
| @@ -403,6 +412,7 @@ SBEXTOBJS += $(CSBEXTOBJS) | |||
| CBAUXOBJS += $(CXERBLAOBJ) | |||
| endif | |||
| BBLASOBJS = $(BBLAS3OBJ) | |||
| SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | |||
| SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) | |||
| DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | |||
| @@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ | |||
| level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | |||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | |||
| level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | |||
| level3 : $(BBLAS3OBJ) $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | |||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | |||
| aux : $(CBAUXOBJS) | |||
| @@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| bgemm.$(SUFFIX) : gemm.c ../param.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| endif | |||
| sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| @@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h | |||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| cblas_bgemm.$(SUFFIX) cblas_bgemm.$(PSUFFIX) : gemm.c ../param.h | |||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | |||
| endif | |||
| cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h | |||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | |||
| @@ -250,6 +250,15 @@ static inline int get_gemm_optimal_nthreads(double MNK) { | |||
| #ifndef CBLAS | |||
| #ifdef BFLOAT16_ONLY | |||
| void NAME(char *TRANSA, char *TRANSB, | |||
| blasint *M, blasint *N, blasint *K, | |||
| float *alpha, | |||
| IFLOAT *a, blasint *ldA, | |||
| IFLOAT *b, blasint *ldB, | |||
| float *beta, | |||
| FLOAT *c, blasint *ldC){ | |||
| #else | |||
| void NAME(char *TRANSA, char *TRANSB, | |||
| blasint *M, blasint *N, blasint *K, | |||
| FLOAT *alpha, | |||
| @@ -257,7 +266,7 @@ void NAME(char *TRANSA, char *TRANSB, | |||
| IFLOAT *b, blasint *ldB, | |||
| FLOAT *beta, | |||
| FLOAT *c, blasint *ldC){ | |||
| #endif | |||
| blas_arg_t args; | |||
| int transa, transb, nrowa, nrowb; | |||
| @@ -366,11 +375,19 @@ void NAME(char *TRANSA, char *TRANSB, | |||
| void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, | |||
| blasint m, blasint n, blasint k, | |||
| #ifndef COMPLEX | |||
| #ifdef BFLOAT16_ONLY | |||
| float alpha, | |||
| IFLOAT *a, blasint lda, | |||
| IFLOAT *b, blasint ldb, | |||
| float beta, | |||
| FLOAT *c, blasint ldc) { | |||
| #else | |||
| FLOAT alpha, | |||
| IFLOAT *a, blasint lda, | |||
| IFLOAT *b, blasint ldb, | |||
| FLOAT beta, | |||
| FLOAT *c, blasint ldc) { | |||
| #endif | |||
| #else | |||
| void *valpha, | |||
| void *va, blasint lda, | |||
| @@ -136,6 +136,25 @@ endif | |||
| endif | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| ifndef BGEMMKERNEL | |||
| BGEMM_BETA = ../generic/gemm_beta.c | |||
| BGEMMKERNEL = ../generic/gemmkernel_2x2.c | |||
| BGEMMINCOPY = ../generic/gemm_ncopy_2.c | |||
| BGEMMITCOPY = ../generic/gemm_tcopy_2.c | |||
| BGEMMONCOPY = ../generic/gemm_ncopy_2.c | |||
| BGEMMOTCOPY = ../generic/gemm_tcopy_2.c | |||
| BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) | |||
| BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) | |||
| BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||
| BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||
| endif | |||
| BKERNELOBJS += \ | |||
| bgemm_kernel$(TSUFFIX).$(SUFFIX) \ | |||
| $(BGEMMINCOPYOBJ) $(BGEMMITCOPYOBJ) \ | |||
| $(BGEMMONCOPYOBJ) $(BGEMMOTCOPYOBJ) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16), 1) | |||
| ifndef SBGEMMKERNEL | |||
| SBGEMM_BETA = ../generic/gemm_beta.c | |||
| @@ -216,6 +235,11 @@ XKERNELOBJS += \ | |||
| $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ | |||
| $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| BBLASOBJS += $(BKERNELOBJS) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| SBBLASOBJS += $(SBKERNELOBJS) | |||
| endif | |||
| @@ -226,6 +250,10 @@ CBLASOBJS += $(CKERNELOBJS) | |||
| ZBLASOBJS += $(ZKERNELOBJS) | |||
| XBLASOBJS += $(XKERNELOBJS) | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| BBLASOBJS += bgemm_beta$(TSUFFIX).$(SUFFIX) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) | |||
| endif | |||
| @@ -651,6 +679,11 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||
| XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||
| XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| $(KDIR)bgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | |||
| @@ -678,6 +711,21 @@ ifeq ($(ARCH), E2K) | |||
| USE_TRMM = 1 | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| # bgemm packing (copy) routines. All four objects must be compiled with | |||
| # -DBFLOAT16_ONLY, the same macro used by the bgemm kernel and beta rules. | |||
| $(KDIR)$(BGEMMONCOPYOBJ) : $(KERNELDIR)/$(BGEMMONCOPY) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| $(KDIR)$(BGEMMOTCOPYOBJ) : $(KERNELDIR)/$(BGEMMOTCOPY) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| $(KDIR)$(BGEMMINCOPYOBJ) : $(KERNELDIR)/$(BGEMMINCOPY) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| $(KDIR)$(BGEMMITCOPYOBJ) : $(KERNELDIR)/$(BGEMMITCOPY) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16), 1) | |||
| @@ -874,6 +922,11 @@ endif | |||
| endif | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||
| $(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16), 1) | |||
| $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) | |||
| @@ -35,7 +35,7 @@ | |||
| #undef ALPHA_ONE | |||
| #include "sbgemm_kernel_4x4_neoversev1_impl.c" | |||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, IFLOAT *A, IFLOAT *B, | |||
| FLOAT *C, BLASLONG ldc) { | |||
| if (alpha == 1.0f) | |||
| return sbgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); | |||
| @@ -78,11 +78,11 @@ | |||
| #ifdef ALPHA_ONE | |||
| int sbgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, | |||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||
| float alpha, IFLOAT *A, IFLOAT *B, | |||
| FLOAT *C, BLASLONG ldc) | |||
| #else | |||
| int sbgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, | |||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, | |||
| float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, | |||
| BLASLONG ldc) | |||
| #endif | |||
| { | |||
| @@ -1,227 +0,0 @@ | |||
| /*************************************************************************** | |||
| * Copyright (c) 2025, The OpenBLAS Project | |||
| * All rights reserved. | |||
| * Redistribution and use in source and binary forms, with or without | |||
| * modification, are permitted provided that the following conditions are | |||
| * met: | |||
| * 1. Redistributions of source code must retain the above copyright | |||
| * notice, this list of conditions and the following disclaimer. | |||
| * 2. Redistributions in binary form must reproduce the above copyright | |||
| * notice, this list of conditions and the following disclaimer in | |||
| * the documentation and/or other materials provided with the | |||
| * distribution. | |||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||
| * its contributors may be used to endorse or promote products | |||
| * derived from this software without specific prior written permission. | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||
| * POSSIBILITY OF SUCH DAMAGE. | |||
| * *****************************************************************************/ | |||
| #include "common.h" | |||
| /* Widen a bfloat16 to float: the 16 input bits become the upper half of the | |||
| * float's bit pattern and the lower half is zero. A union is used for the | |||
| * reinterpretation so the float is not accessed through an unsigned short pointer. */ | |||
| static float bfloat16tof32(bfloat16 f16) { | |||
| union { float f; unsigned short h[2]; } bits; | |||
| bits.f = 0; | |||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||
| bits.h[0] = f16; | |||
| #else | |||
| bits.h[1] = f16; | |||
| #endif | |||
| return bits.f; | |||
| } | |||
| /* Narrow a float to bfloat16 by returning the upper 16 bits of its bit pattern | |||
| * (plain truncation, no rounding). A union is used for the reinterpretation so | |||
| * the float is not accessed through an unsigned short pointer. */ | |||
| static bfloat16 f32tobfloat16(float f32) { | |||
| union { float f; unsigned short h[2]; } bits; | |||
| bits.f = f32; | |||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||
| return bits.h[0]; | |||
| #else | |||
| return bits.h[1]; | |||
| #endif | |||
| } | |||
| #define BF16TOF32(x) (bfloat16tof32(x)) | |||
| #define F32TOBF16(x) (f32tobfloat16(x)) | |||
| /* | |||
| * 2x2 register-blocked GEMM micro-kernel: C += alpha * A * B. | |||
| * | |||
| * bm, bn, bk : extents of the update (rows of C, columns of C, depth). | |||
| * alpha      : scaling factor; widened to float with BF16TOF32 before use. | |||
| * ba, bb     : input panels, consumed sequentially, two values per k step | |||
| *              (one for the odd row / odd column tails). Presumably produced | |||
| *              by the matching gemm copy routines -- confirm. | |||
| * C, ldc     : output matrix and the element distance between its columns. | |||
| * | |||
| * Products are accumulated in float. Each C element is widened, updated and | |||
| * narrowed back with F32TOBF16. Always returns 0. | |||
| */ | |||
| int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba, | |||
| IFLOAT *bb, FLOAT *C, BLASLONG ldc) { | |||
| BLASLONG i, j, k; | |||
| FLOAT *C0, *C1; // bfloat16 | |||
| IFLOAT *ptrba, *ptrbb; | |||
| float res0, res1, res2, res3; | |||
| float load0, load1, load2, load3, load4, load5, load6, load7; | |||
| float alpha_ = BF16TOF32(alpha); | |||
| /* Column pairs of C. */ | |||
| for (j = 0; j < bn / 2; j += 1) { | |||
| C0 = C; | |||
| C1 = C0 + ldc; | |||
| ptrba = ba; | |||
| /* Row pairs: full 2x2 tile. */ | |||
| for (i = 0; i < bm / 2; i += 1) { | |||
| ptrbb = bb; | |||
| res0 = 0; | |||
| res1 = 0; | |||
| res2 = 0; | |||
| res3 = 0; | |||
| /* Depth loop unrolled by 4: each pass consumes 8 values from each panel. */ | |||
| for (k = 0; k < bk / 4; k += 1) { | |||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||
| load4 = BF16TOF32(ptrba[2 * 1 + 0]); | |||
| load6 = BF16TOF32(ptrba[2 * 1 + 1]); | |||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||
| load3 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||
| load5 = BF16TOF32(ptrbb[2 * 1 + 0]); | |||
| load7 = BF16TOF32(ptrbb[2 * 1 + 1]); | |||
| res0 = res0 + load0 * load1; | |||
| res1 = res1 + load2 * load1; | |||
| res2 = res2 + load0 * load3; | |||
| res3 = res3 + load2 * load3; | |||
| res0 = res0 + load4 * load5; | |||
| res1 = res1 + load6 * load5; | |||
| res2 = res2 + load4 * load7; | |||
| res3 = res3 + load6 * load7; | |||
| load0 = BF16TOF32(ptrba[2 * 2 + 0]); | |||
| load2 = BF16TOF32(ptrba[2 * 2 + 1]); | |||
| load4 = BF16TOF32(ptrba[2 * 3 + 0]); | |||
| load6 = BF16TOF32(ptrba[2 * 3 + 1]); | |||
| load1 = BF16TOF32(ptrbb[2 * 2 + 0]); | |||
| load3 = BF16TOF32(ptrbb[2 * 2 + 1]); | |||
| load5 = BF16TOF32(ptrbb[2 * 3 + 0]); | |||
| load7 = BF16TOF32(ptrbb[2 * 3 + 1]); | |||
| res0 = res0 + load0 * load1; | |||
| res1 = res1 + load2 * load1; | |||
| res2 = res2 + load0 * load3; | |||
| res3 = res3 + load2 * load3; | |||
| res0 = res0 + load4 * load5; | |||
| res1 = res1 + load6 * load5; | |||
| res2 = res2 + load4 * load7; | |||
| res3 = res3 + load6 * load7; | |||
| /* Step past the 4 k-steps just consumed; without this every pass | |||
| * would re-read the same values and later tiles would start at the | |||
| * wrong offset. */ | |||
| ptrba = ptrba + 8; | |||
| ptrbb = ptrbb + 8; | |||
| } | |||
| /* Remaining bk % 4 depth steps, one at a time. */ | |||
| for (k = 0; k < (bk & 3); k += 1) { | |||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||
| load3 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||
| res0 = res0 + load0 * load1; | |||
| res1 = res1 + load2 * load1; | |||
| res2 = res2 + load0 * load3; | |||
| res3 = res3 + load2 * load3; | |||
| ptrba = ptrba + 2; | |||
| ptrbb = ptrbb + 2; | |||
| } | |||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||
| res1 = res1 * alpha_ + BF16TOF32(C0[1]); | |||
| res2 = res2 * alpha_ + BF16TOF32(C1[0]); | |||
| res3 = res3 * alpha_ + BF16TOF32(C1[1]); | |||
| C0[0] = F32TOBF16(res0); | |||
| C0[1] = F32TOBF16(res1); | |||
| C1[0] = F32TOBF16(res2); | |||
| C1[1] = F32TOBF16(res3); | |||
| C0 = C0 + 2; | |||
| C1 = C1 + 2; | |||
| } | |||
| /* Odd last row: 1x2 tile. */ | |||
| for (i = 0; i < (bm & 1); i += 1) { | |||
| ptrbb = bb; | |||
| res0 = 0; | |||
| res1 = 0; | |||
| for (k = 0; k < bk; k += 1) { | |||
| load0 = BF16TOF32(ptrba[0 + 0]); | |||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||
| load2 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||
| res0 = res0 + load0 * load1; | |||
| res1 = res1 + load0 * load2; | |||
| ptrba = ptrba + 1; | |||
| ptrbb = ptrbb + 2; | |||
| } | |||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||
| res1 = res1 * alpha_ + BF16TOF32(C1[0]); | |||
| /* Narrow explicitly, as every other store in this kernel does; a plain | |||
| * assignment would convert the float value instead of its bit pattern. */ | |||
| C0[0] = F32TOBF16(res0); | |||
| C1[0] = F32TOBF16(res1); | |||
| C0 = C0 + 1; | |||
| C1 = C1 + 1; | |||
| } | |||
| /* Advance to the next column pair: 2*bk values of B, 2 columns of C. */ | |||
| k = (bk << 1); | |||
| bb = bb + k; | |||
| i = (ldc << 1); | |||
| C = C + i; | |||
| } | |||
| /* Odd last column of C. */ | |||
| for (j = 0; j < (bn & 1); j += 1) { | |||
| C0 = C; | |||
| ptrba = ba; | |||
| /* Row pairs: 2x1 tile. */ | |||
| for (i = 0; i < bm / 2; i += 1) { | |||
| ptrbb = bb; | |||
| res0 = 0; | |||
| res1 = 0; | |||
| for (k = 0; k < bk; k += 1) { | |||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||
| load1 = BF16TOF32(ptrbb[0 + 0]); | |||
| res0 = res0 + load0 * load1; | |||
| res1 = res1 + load2 * load1; | |||
| ptrba = ptrba + 2; | |||
| ptrbb = ptrbb + 1; | |||
| } | |||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||
| res1 = res1 * alpha_ + BF16TOF32(C0[1]); | |||
| C0[0] = F32TOBF16(res0); | |||
| C0[1] = F32TOBF16(res1); | |||
| C0 = C0 + 2; | |||
| } | |||
| /* Odd last row: single element. */ | |||
| for (i = 0; i < (bm & 1); i += 1) { | |||
| ptrbb = bb; | |||
| res0 = 0; | |||
| for (k = 0; k < bk; k += 1) { | |||
| load0 = BF16TOF32(ptrba[0 + 0]); | |||
| load1 = BF16TOF32(ptrbb[0 + 0]); | |||
| res0 += load0 * load1; | |||
| ptrba = ptrba + 1; | |||
| ptrbb = ptrbb + 1; | |||
| } | |||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||
| C0[0] = F32TOBF16(res0); | |||
| C0 = C0 + 1; | |||
| } | |||
| k = (bk << 0); | |||
| bb = bb + k; | |||
| C = C + ldc; | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -71,11 +71,15 @@ f32tobfloat16(float f32) | |||
| #define F32TOBF16(x) x | |||
| #endif | |||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, | |||
| #if defined(BFLOAT16_ONLY) | |||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, float beta, | |||
| IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, | |||
| FLOAT *c, BLASLONG ldc){ | |||
| #else | |||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, | |||
| IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, | |||
| FLOAT *c, BLASLONG ldc){ | |||
| #endif | |||
| BLASLONG i, j; | |||
| BLASLONG chunk, remain; | |||
| @@ -83,25 +87,24 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, | |||
| c_offset = c; | |||
| chunk = m >> 3; | |||
| remain = m & 7; | |||
| float beta = BF16TOF32(beta_in); | |||
| if (beta == ZERO){ | |||
| for(j=n; j>0; j--){ | |||
| c_offset1 = c_offset; | |||
| c_offset += ldc; | |||
| for(i=chunk; i>0; i--){ | |||
| *(c_offset1 + 0) = ZERO; | |||
| *(c_offset1 + 1) = ZERO; | |||
| *(c_offset1 + 2) = ZERO; | |||
| *(c_offset1 + 3) = ZERO; | |||
| *(c_offset1 + 4) = ZERO; | |||
| *(c_offset1 + 5) = ZERO; | |||
| *(c_offset1 + 6) = ZERO; | |||
| *(c_offset1 + 7) = ZERO; | |||
| *(c_offset1 + 0) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 1) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 2) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 3) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 4) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 5) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 6) = F32TOBF16(ZERO); | |||
| *(c_offset1 + 7) = F32TOBF16(ZERO); | |||
| c_offset1 += 8; | |||
| } | |||
| for(i=remain; i>0; i--){ | |||
| *c_offset1 = ZERO; | |||
| *c_offset1 = F32TOBF16(ZERO); | |||
| c_offset1 ++; | |||
| } | |||
| } | |||
| @@ -1,5 +1,32 @@ | |||
| /*************************************************************************** | |||
| Copyright (c) 2025 The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE | |||
| GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |||
| HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF | |||
| THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| *****************************************************************************/ | |||
| #include "common.h" | |||
| #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) | |||
| #if (defined(BFLOAT16) || defined(BFLOAT16_ONLY))&& defined(BFLOAT16CONVERSION) | |||
| static float | |||
| bfloat16tof32 (bfloat16 f16) | |||
| { | |||
| @@ -12,12 +39,29 @@ bfloat16tof32 (bfloat16 f16) | |||
| #endif | |||
| return result; | |||
| } | |||
| /* Narrow a float to bfloat16 by returning the upper 16 bits of its bit pattern | |||
| * (plain truncation, no rounding). A union is used for the reinterpretation so | |||
| * the float is not accessed through an unsigned short pointer. */ | |||
| static bfloat16 f32tobfloat16(float f32) { | |||
| union { float f; unsigned short h[2]; } bits; | |||
| bits.f = f32; | |||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||
| return bits.h[0]; | |||
| #else | |||
| return bits.h[1]; | |||
| #endif | |||
| } | |||
| #define BF16TOF32(x) (bfloat16tof32(x)) | |||
| #define F32TOBF16(x) (f32tobfloat16(x)) | |||
| #else | |||
| #define BF16TOF32(x) x | |||
| #define F32TOBF16(x) x | |||
| #endif | |||
| #ifdef BFLOAT16_ONLY | |||
| int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk, float alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc | |||
| #else | |||
| int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc | |||
| #ifdef TRMMKERNEL | |||
| #endif | |||
| #ifdef TRMMKERNEL | |||
| ,BLASLONG offset | |||
| #endif | |||
| ) | |||
| @@ -90,13 +134,17 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||
| ptrbb = ptrbb+2; | |||
| } | |||
| res0 = res0*alpha; | |||
| C0[0] = C0[0]+res0; | |||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||
| res1 = res1*alpha; | |||
| C0[1] = C0[1]+res1; | |||
| C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); | |||
| res2 = res2*alpha; | |||
| C1[0] = C1[0]+res2; | |||
| C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); | |||
| res3 = res3*alpha; | |||
| C1[1] = C1[1]+res3; | |||
| C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[1])); | |||
| C0 = C0+2; | |||
| C1 = C1+2; | |||
| } | |||
| @@ -116,9 +164,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||
| ptrbb = ptrbb+2; | |||
| } | |||
| res0 = res0*alpha; | |||
| C0[0] = C0[0]+res0; | |||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||
| res1 = res1*alpha; | |||
| C1[0] = C1[0]+res1; | |||
| C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); | |||
| C0 = C0+1; | |||
| C1 = C1+1; | |||
| } | |||
| @@ -147,9 +197,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||
| ptrbb = ptrbb+1; | |||
| } | |||
| res0 = res0*alpha; | |||
| C0[0] = C0[0]+res0; | |||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||
| res1 = res1*alpha; | |||
| C0[1] = C0[1]+res1; | |||
| C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); | |||
| C0 = C0+2; | |||
| } | |||
| for (i=0; i<(bm&1); i+=1) | |||
| @@ -165,7 +217,8 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||
| ptrbb = ptrbb+1; | |||
| } | |||
| res0 = res0*alpha; | |||
| C0[0] = C0[0]+res0; | |||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||
| C0 = C0+1; | |||
| } | |||
| k = (bk<<0); | |||
| @@ -1,3 +1,31 @@ | |||
| ############################################################################### | |||
| # Copyright (c) 2025, The OpenBLAS Project | |||
| # All rights reserved. | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are | |||
| # met: | |||
| # 1. Redistributions of source code must retain the above copyright | |||
| # notice, this list of conditions and the following disclaimer. | |||
| # 2. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in | |||
| # the documentation and/or other materials provided with the | |||
| # distribution. | |||
| # 3. Neither the name of the OpenBLAS project nor the names of | |||
| # its contributors may be used to endorse or promote products | |||
| # derived from this software without specific prior written permission. | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||
| # POSSIBILITY OF SUCH DAMAGE. | |||
| ############################################################################### | |||
| TOPDIR = .. | |||
| include ../Makefile.system | |||
| ifeq ($(F_COMPILER),GFORTRAN) | |||
| @@ -164,6 +192,9 @@ endif | |||
| endif | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| BF3= test_bgemm | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| B3= test_sbgemm | |||
| endif | |||
| @@ -192,11 +223,15 @@ endif | |||
| ifeq ($(SUPPORT_GEMM3M),1) | |||
| level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m | |||
| else | |||
| level3: $(B3) $(S3) $(D3) $(C3) $(Z3) | |||
| level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) | |||
| endif | |||
| ifneq ($(CROSS), 1) | |||
| rm -f ?BLAT3.SUMM | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM | |||
| @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM | |||
| @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 | |||
| @@ -366,6 +401,11 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) | |||
| $(FC) $(FLDFLAGS) -o zblat3 zblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||
| test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) | |||
| $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) | |||
| $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | |||
| @@ -387,7 +427,7 @@ clean: | |||
| @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ | |||
| sblat1 dblat1 cblat1 zblat1 \ | |||
| sblat2 dblat2 cblat2 zblat2 \ | |||
| test_sbgemm sblat3 dblat3 cblat3 zblat3 \ | |||
| test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ | |||
| sblat1p dblat1p cblat1p zblat1p \ | |||
| sblat2p dblat2p cblat2p zblat2p \ | |||
| sblat3p dblat3p cblat3p zblat3p \ | |||
| @@ -57,10 +57,8 @@ int main(int argc, char *argv[]) { | |||
| char transA = 'N', transB = 'N'; | |||
| float alpha = 1.0, beta = 0.0; | |||
| bfloat16 alpha_bf16 = convert_to_bf16(alpha), | |||
| beta_bf16 = convert_to_bf16(beta); | |||
| for (x = 1; x <= BGEMM_LARGEST; x++) { | |||
| for (x = 1; x <= loop; x++) { | |||
| if ((x > 100) && (x != BGEMM_LARGEST)) | |||
| continue; | |||
| m = k = n = x; | |||
| @@ -79,14 +77,14 @@ int main(int argc, char *argv[]) { | |||
| for (int i = 0; i < m; i++) { | |||
| for (int j = 0; j < k; j++) { | |||
| AA[i * k + j] = (i * k + j + 1) % 100; | |||
| AA[i * k + j] = (i * k + j + 1) % 5; | |||
| A[i * k + j] = AA[i * k + j]; | |||
| } | |||
| } | |||
| for (int i = 0; i < n; i++) { | |||
| for (int j = 0; j < k; j++) { | |||
| BB[i * k + j] = (i * k + j + 1) % 100; | |||
| BB[i * k + j] = (i * k + j + 1) % 5; | |||
| B[i * k + j] = BB[i * k + j]; | |||
| } | |||
| } | |||
| @@ -102,23 +100,73 @@ int main(int argc, char *argv[]) { | |||
| } else { | |||
| transB = 'T'; | |||
| } | |||
| // printf("******** x = %d, y = %d********\n", x, y); | |||
| // printf("Matrix AA (m x k):\n"); | |||
| // for (int i = 0; i < m; i++) { | |||
| // for (int j = 0; j < k; j++) { | |||
| // printf("%.2f ", (float)AA[i * k + j]); // or %4.1f if float | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| // printf("Matrix A (copy of AA):\n"); | |||
| // for (int i = 0; i < m; i++) { | |||
| // for (int j = 0; j < k; j++) { | |||
| // printf("%.2f ", A[i * k + j]); | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| // printf("Matrix BB (n x k):\n"); | |||
| // for (int i = 0; i < n; i++) { | |||
| // for (int j = 0; j < k; j++) { | |||
| // printf("%.2f ", (float)BB[i * k + j]); | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| // printf("Matrix B (copy of BB):\n"); | |||
| // for (int i = 0; i < n; i++) { | |||
| // for (int j = 0; j < k; j++) { | |||
| // printf("%.2f ", B[i * k + j]); | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| memset(C, 0, m * n * sizeof(FLOAT)); | |||
| memset(CC, 0, m * n * sizeof(bfloat16)); | |||
| SGEMM(&transA, &transB, &m, &n, &k, &alpha_bf16, A, &m, B, &k, &beta_bf16, | |||
| SGEMM(&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, | |||
| C, &m); | |||
| BGEMM(&transA, &transB, &m, &n, &k, &alpha, (bfloat16 *)AA, &m, | |||
| (bfloat16 *)BB, &k, &beta, (bfloat16 *)CC, &m); | |||
| // printf("Matrix CC (n x m):\n"); | |||
| // for (int i = 0; i < n; i++) { | |||
| // for (int j = 0; j < m; j++) { | |||
| // printf("%.2f ", (float)CC[i * m + j]); | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| // printf("Matrix C :\n"); | |||
| // for (int i = 0; i < n; i++) { | |||
| // for (int j = 0; j < k; j++) { | |||
| // printf("%.2f ", C[i * k + j]); | |||
| // } | |||
| // printf("\n"); | |||
| // } | |||
| for (i = 0; i < n; i++) { | |||
| for (j = 0; j < m; j++) { | |||
| for (l = 0; l < k; l++) { | |||
| if (fabs(CC[i * m + j] - C[i * m + j]) > 1.0) { | |||
| ret++; | |||
| if (fabs((float)CC[i * m + j] - C[i * m + j]) > 1.0) { | |||
| ret ++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| printf("x = %d, err = %d\n", x, ret); | |||
| ret = 0; | |||
| } | |||
| free(A); | |||