| @@ -251,6 +251,7 @@ In chronological order: | |||||
| * Ye Tao <ye.tao@arm.com> | * Ye Tao <ye.tao@arm.com> | ||||
| * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 | * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 | ||||
| * [2025-02-27] Add sbgemv_n_neon kernel | * [2025-02-27] Add sbgemv_n_neon kernel | ||||
| * [2025-05-17] Impl prototype of BGEMM interface | |||||
| * Abhishek Kumar <https://github.com/abhishek-iitmadras> | * Abhishek Kumar <https://github.com/abhishek-iitmadras> | ||||
| * [2025-04-22] Optimise dot kernel for NEOVERSE V1 | |||||
| * [2025-04-22] Optimise dot kernel for NEOVERSE V1 | |||||
| @@ -1544,6 +1544,9 @@ ifeq ($(USE_TLS), 1) | |||||
| CCOMMON_OPT += -DUSE_TLS | CCOMMON_OPT += -DUSE_TLS | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| CCOMMON_OPT += -DBUILD_BFLOAT16_ONLY | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| CCOMMON_OPT += -DBUILD_BFLOAT16 | CCOMMON_OPT += -DBUILD_BFLOAT16 | ||||
| endif | endif | ||||
| @@ -1888,6 +1891,7 @@ export FUNCTION_PROFILE | |||||
| export TARGET_CORE | export TARGET_CORE | ||||
| export NO_AVX512 | export NO_AVX512 | ||||
| export NO_AVX2 | export NO_AVX2 | ||||
| export BUILD_BFLOAT16_ONLY | |||||
| export BUILD_BFLOAT16 | export BUILD_BFLOAT16 | ||||
| export NO_LSX | export NO_LSX | ||||
| export NO_LASX | export NO_LASX | ||||
| @@ -1912,7 +1916,7 @@ export ZGEMM3M_UNROLL_M | |||||
| export ZGEMM3M_UNROLL_N | export ZGEMM3M_UNROLL_N | ||||
| export XGEMM3M_UNROLL_M | export XGEMM3M_UNROLL_M | ||||
| export XGEMM3M_UNROLL_N | export XGEMM3M_UNROLL_N | ||||
| # Todo: add bgemm unroll factors | |||||
| ifdef USE_CUDA | ifdef USE_CUDA | ||||
| export CUDADIR | export CUDADIR | ||||
| @@ -11,7 +11,7 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||||
| BLASOBJS = $(SBEXTOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||||
| BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) | BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) | ||||
| ifdef EXPRECISION | ifdef EXPRECISION | ||||
| @@ -24,6 +24,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||||
| BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) | BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) | ||||
| endif | endif | ||||
| $(BBLASOBJS) : override CFLAGS += -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX -UBFLOAT16 -USMALL_MATRIX_OPT | |||||
| $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX | $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX | ||||
| $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX | $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX | ||||
| $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX | $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX | ||||
| @@ -42,6 +43,7 @@ $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||||
| $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
| $(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
| libs :: $(BLASOBJS) $(COMMONOBJS) | libs :: $(BLASOBJS) $(COMMONOBJS) | ||||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
| @@ -475,7 +475,7 @@ void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST en | |||||
| OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | ||||
| void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, | void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, | ||||
| OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); | |||||
| OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -481,8 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint | |||||
| xdouble *, blasint *, xdouble *, xdouble *, blasint *); | xdouble *, blasint *, xdouble *, xdouble *, blasint *); | ||||
| /* Level 3 routines */ | /* Level 3 routines */ | ||||
| void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *, | |||||
| bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); | |||||
| void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||||
| bfloat16 *, blasint *, bfloat16 *, blasint *, float *, bfloat16 *, blasint *); | |||||
| void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | ||||
| bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); | bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); | ||||
| void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | ||||
| @@ -54,7 +54,7 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||||
| int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | ||||
| int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||||
| int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||||
| bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | ||||
| int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | ||||
| bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | ||||
| @@ -513,7 +513,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl | |||||
| int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); | int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); | ||||
| // add bgemm kernel | // add bgemm kernel | ||||
| int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||||
| int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||||
| int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); | int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); | ||||
| int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); | int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); | ||||
| int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); | int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); | ||||
| @@ -54,8 +54,8 @@ typedef struct { | |||||
| int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; | int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; | ||||
| int bgemm_align_k; | int bgemm_align_k; | ||||
| int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||||
| int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||||
| int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); | |||||
| int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||||
| int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | ||||
| int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); | ||||
| @@ -52,6 +52,13 @@ ifeq ($(BUILD_BFLOAT16),1) | |||||
| SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) | SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX) | |||||
| endif | |||||
| BLASOBJS += \ | |||||
| gemm_nn.$(SUFFIX) gemm_nt.$(SUFFIX) gemm_tn.$(SUFFIX) gemm_tt.$(SUFFIX) | |||||
| SBLASOBJS += \ | SBLASOBJS += \ | ||||
| sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ | sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ | ||||
| strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ | strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ | ||||
| @@ -376,6 +383,18 @@ endif | |||||
| all :: | all :: | ||||
| bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | |||||
| bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) | |||||
| bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h | |||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) | |||||
| bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) | |||||
| sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h | sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h | ||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | ||||
| @@ -432,8 +451,8 @@ cgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h | |||||
| cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h | cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h | ||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F) | $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F) | ||||
| cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h | cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h | ||||
| $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F) | $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F) | ||||
| cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h | cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h | ||||
| @@ -43,9 +43,11 @@ | |||||
| #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) | #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) | ||||
| #ifndef COMPLEX | #ifndef COMPLEX | ||||
| #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | ||||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | |||||
| BETA[0], NULL, 0, NULL, 0, \ | |||||
| (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC) | |||||
| do { \ | |||||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | |||||
| BETA[0], NULL, 0, NULL, 0, \ | |||||
| (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC); \ | |||||
| } while (0) | |||||
| #else | #else | ||||
| #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ | ||||
| GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ | ||||
| @@ -189,7 +191,11 @@ | |||||
| int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | ||||
| XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ | XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ | ||||
| BLASLONG k, lda, ldb, ldc; | BLASLONG k, lda, ldb, ldc; | ||||
| #if defined(BUILD_BFLOAT16_ONLY) | |||||
| float *alpha, *beta; | |||||
| #else | |||||
| FLOAT *alpha, *beta; | FLOAT *alpha, *beta; | ||||
| #endif | |||||
| IFLOAT *a, *b; | IFLOAT *a, *b; | ||||
| FLOAT *c; | FLOAT *c; | ||||
| BLASLONG m_from, m_to, n_from, n_to; | BLASLONG m_from, m_to, n_from, n_to; | ||||
| @@ -224,8 +230,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||||
| ldb = LDB; | ldb = LDB; | ||||
| ldc = LDC; | ldc = LDC; | ||||
| #if defined(BUILD_BFLOAT16_ONLY) | |||||
| alpha = (float *)args -> alpha; | |||||
| beta = (float *)args -> beta; | |||||
| #else | |||||
| alpha = (FLOAT *)args -> alpha; | alpha = (FLOAT *)args -> alpha; | ||||
| beta = (FLOAT *)args -> beta; | beta = (FLOAT *)args -> beta; | ||||
| #endif | |||||
| m_from = 0; | m_from = 0; | ||||
| m_to = M; | m_to = M; | ||||
| @@ -239,7 +239,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||||
| BLASLONG k, lda, ldb, ldc; | BLASLONG k, lda, ldb, ldc; | ||||
| BLASLONG m_from, m_to, n_from, n_to; | BLASLONG m_from, m_to, n_from, n_to; | ||||
| #if defined(BUILD_BFLOAT16_ONLY) | |||||
| float *alpha, *beta; | |||||
| #else | |||||
| FLOAT *alpha, *beta; | FLOAT *alpha, *beta; | ||||
| #endif | |||||
| IFLOAT *a, *b; | IFLOAT *a, *b; | ||||
| FLOAT *c; | FLOAT *c; | ||||
| job_t *job = (job_t *)args -> common; | job_t *job = (job_t *)args -> common; | ||||
| @@ -277,8 +281,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, | |||||
| ldb = LDB; | ldb = LDB; | ||||
| ldc = LDC; | ldc = LDC; | ||||
| #if defined(BUILD_BFLOAT16_ONLY) | |||||
| alpha = (float *)args -> alpha; | |||||
| beta = (float *)args -> beta; | |||||
| #else | |||||
| alpha = (FLOAT *)args -> alpha; | alpha = (FLOAT *)args -> alpha; | ||||
| beta = (FLOAT *)args -> beta; | beta = (FLOAT *)args -> beta; | ||||
| #endif | |||||
| /* Initialize 2D CPU distribution */ | /* Initialize 2D CPU distribution */ | ||||
| nthreads_m = args -> nthreads; | nthreads_m = args -> nthreads; | ||||
| @@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) | |||||
| SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| BBLAS3OBJ = bgemm.$(SUFFIX) | |||||
| endif | |||||
| DBLAS1OBJS = \ | DBLAS1OBJS = \ | ||||
| daxpy.$(SUFFIX) dswap.$(SUFFIX) \ | daxpy.$(SUFFIX) dswap.$(SUFFIX) \ | ||||
| dcopy.$(SUFFIX) dscal.$(SUFFIX) \ | dcopy.$(SUFFIX) dscal.$(SUFFIX) \ | ||||
| @@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S | |||||
| CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) | |||||
| endif | |||||
| CDBLAS1OBJS = \ | CDBLAS1OBJS = \ | ||||
| cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ | cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ | ||||
| cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ | cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ | ||||
| @@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS) | |||||
| SBBLAS1OBJS += $(CSBBLAS1OBJS) | SBBLAS1OBJS += $(CSBBLAS1OBJS) | ||||
| SBBLAS2OBJS += $(CSBBLAS2OBJS) | SBBLAS2OBJS += $(CSBBLAS2OBJS) | ||||
| SBBLAS3OBJS += $(CSBBLAS3OBJS) | SBBLAS3OBJS += $(CSBBLAS3OBJS) | ||||
| BBLAS3OBJ += $(CBBLAS3OBJS) | |||||
| DBLAS1OBJS += $(CDBLAS1OBJS) | DBLAS1OBJS += $(CDBLAS1OBJS) | ||||
| DBLAS2OBJS += $(CDBLAS2OBJS) | DBLAS2OBJS += $(CDBLAS2OBJS) | ||||
| DBLAS3OBJS += $(CDBLAS3OBJS) | DBLAS3OBJS += $(CDBLAS3OBJS) | ||||
| @@ -403,6 +412,7 @@ SBEXTOBJS += $(CSBEXTOBJS) | |||||
| CBAUXOBJS += $(CXERBLAOBJ) | CBAUXOBJS += $(CXERBLAOBJ) | ||||
| endif | endif | ||||
| BBLASOBJS = $(BBLAS3OBJ) | |||||
| SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | ||||
| SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) | SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) | ||||
| DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | ||||
| @@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ | |||||
| level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | ||||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
| level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | |||||
| level3 : $(BBLAS3OBJ) $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | |||||
| $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
| aux : $(CBAUXOBJS) | aux : $(CBAUXOBJS) | ||||
| @@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h | |||||
| $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) | $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| bgemm.$(SUFFIX) : gemm.c ../param.h | |||||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||||
| endif | |||||
| sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h | sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h | ||||
| $(CC) -c $(CFLAGS) $< -o $(@F) | $(CC) -c $(CFLAGS) $< -o $(@F) | ||||
| @@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h | |||||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| cblas_bgemm.$(SUFFIX) cblas_bgemm.$(PSUFFIX) : gemm.c ../param.h | |||||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | |||||
| endif | |||||
| cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h | cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h | ||||
| $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | ||||
| @@ -250,6 +250,15 @@ static inline int get_gemm_optimal_nthreads(double MNK) { | |||||
| #ifndef CBLAS | #ifndef CBLAS | ||||
| #ifdef BFLOAT16_ONLY | |||||
| void NAME(char *TRANSA, char *TRANSB, | |||||
| blasint *M, blasint *N, blasint *K, | |||||
| float *alpha, | |||||
| IFLOAT *a, blasint *ldA, | |||||
| IFLOAT *b, blasint *ldB, | |||||
| float *beta, | |||||
| FLOAT *c, blasint *ldC){ | |||||
| #else | |||||
| void NAME(char *TRANSA, char *TRANSB, | void NAME(char *TRANSA, char *TRANSB, | ||||
| blasint *M, blasint *N, blasint *K, | blasint *M, blasint *N, blasint *K, | ||||
| FLOAT *alpha, | FLOAT *alpha, | ||||
| @@ -257,7 +266,7 @@ void NAME(char *TRANSA, char *TRANSB, | |||||
| IFLOAT *b, blasint *ldB, | IFLOAT *b, blasint *ldB, | ||||
| FLOAT *beta, | FLOAT *beta, | ||||
| FLOAT *c, blasint *ldC){ | FLOAT *c, blasint *ldC){ | ||||
| #endif | |||||
| blas_arg_t args; | blas_arg_t args; | ||||
| int transa, transb, nrowa, nrowb; | int transa, transb, nrowa, nrowb; | ||||
| @@ -366,11 +375,19 @@ void NAME(char *TRANSA, char *TRANSB, | |||||
| void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, | void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, | ||||
| blasint m, blasint n, blasint k, | blasint m, blasint n, blasint k, | ||||
| #ifndef COMPLEX | #ifndef COMPLEX | ||||
| #ifdef BFLOAT16_ONLY | |||||
| float alpha, | |||||
| IFLOAT *a, blasint lda, | |||||
| IFLOAT *b, blasint ldb, | |||||
| float beta, | |||||
| FLOAT *c, blasint ldc) { | |||||
| #else | |||||
| FLOAT alpha, | FLOAT alpha, | ||||
| IFLOAT *a, blasint lda, | IFLOAT *a, blasint lda, | ||||
| IFLOAT *b, blasint ldb, | IFLOAT *b, blasint ldb, | ||||
| FLOAT beta, | FLOAT beta, | ||||
| FLOAT *c, blasint ldc) { | FLOAT *c, blasint ldc) { | ||||
| #endif | |||||
| #else | #else | ||||
| void *valpha, | void *valpha, | ||||
| void *va, blasint lda, | void *va, blasint lda, | ||||
| @@ -136,6 +136,25 @@ endif | |||||
| endif | endif | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| ifndef BGEMMKERNEL | |||||
| BGEMM_BETA = ../generic/gemm_beta.c | |||||
| BGEMMKERNEL = ../generic/gemmkernel_2x2.c | |||||
| BGEMMINCOPY = ../generic/gemm_ncopy_2.c | |||||
| BGEMMITCOPY = ../generic/gemm_tcopy_2.c | |||||
| BGEMMONCOPY = ../generic/gemm_ncopy_2.c | |||||
| BGEMMOTCOPY = ../generic/gemm_tcopy_2.c | |||||
| BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) | |||||
| BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) | |||||
| BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||||
| BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||||
| endif | |||||
| BKERNELOBJS += \ | |||||
| bgemm_kernel$(TSUFFIX).$(SUFFIX) \ | |||||
| $(BGEMMINCOPYOBJ) $(BGEMMITCOPYOBJ) \ | |||||
| $(BGEMMONCOPYOBJ) $(BGEMMOTCOPYOBJ) | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| ifndef SBGEMMKERNEL | ifndef SBGEMMKERNEL | ||||
| SBGEMM_BETA = ../generic/gemm_beta.c | SBGEMM_BETA = ../generic/gemm_beta.c | ||||
| @@ -216,6 +235,11 @@ XKERNELOBJS += \ | |||||
| $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ | $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ | ||||
| $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) | $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| BBLASOBJS += $(BKERNELOBJS) | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| SBBLASOBJS += $(SBKERNELOBJS) | SBBLASOBJS += $(SBKERNELOBJS) | ||||
| endif | endif | ||||
| @@ -226,6 +250,10 @@ CBLASOBJS += $(CKERNELOBJS) | |||||
| ZBLASOBJS += $(ZKERNELOBJS) | ZBLASOBJS += $(ZKERNELOBJS) | ||||
| XBLASOBJS += $(XKERNELOBJS) | XBLASOBJS += $(XKERNELOBJS) | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| BBLASOBJS += bgemm_beta$(TSUFFIX).$(SUFFIX) | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) | SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) | ||||
| endif | endif | ||||
| @@ -651,6 +679,11 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| $(KDIR)bgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) | $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) | ||||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | ||||
| @@ -678,6 +711,21 @@ ifeq ($(ARCH), E2K) | |||||
| USE_TRMM = 1 | USE_TRMM = 1 | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| # Packing (copy) objects for the all-bfloat16 GEMM. Every rule defines BFLOAT16_ONLY, the same macro used by the bgemm_beta and bgemm_kernel rules; the inner-copy rules previously passed -DDBFLOAT16_ONLY, which defines a different, unused macro. | |||||
| $(KDIR)$(BGEMMONCOPYOBJ) : $(KERNELDIR)/$(BGEMMONCOPY) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(KDIR)$(BGEMMOTCOPYOBJ) : $(KERNELDIR)/$(BGEMMOTCOPY) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(KDIR)$(BGEMMINCOPYOBJ) : $(KERNELDIR)/$(BGEMMINCOPY) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(KDIR)$(BGEMMITCOPYOBJ) : $(KERNELDIR)/$(BGEMMITCOPY) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| @@ -874,6 +922,11 @@ endif | |||||
| endif | endif | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY), 1) | |||||
| $(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) | $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) | ||||
| @@ -35,7 +35,7 @@ | |||||
| #undef ALPHA_ONE | #undef ALPHA_ONE | ||||
| #include "sbgemm_kernel_4x4_neoversev1_impl.c" | #include "sbgemm_kernel_4x4_neoversev1_impl.c" | ||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, IFLOAT *A, IFLOAT *B, | |||||
| FLOAT *C, BLASLONG ldc) { | FLOAT *C, BLASLONG ldc) { | ||||
| if (alpha == 1.0f) | if (alpha == 1.0f) | ||||
| return sbgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); | return sbgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); | ||||
| @@ -78,11 +78,11 @@ | |||||
| #ifdef ALPHA_ONE | #ifdef ALPHA_ONE | ||||
| int sbgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, | int sbgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, | ||||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||||
| float alpha, IFLOAT *A, IFLOAT *B, | |||||
| FLOAT *C, BLASLONG ldc) | FLOAT *C, BLASLONG ldc) | ||||
| #else | #else | ||||
| int sbgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, | int sbgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, | ||||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, | |||||
| float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, | |||||
| BLASLONG ldc) | BLASLONG ldc) | ||||
| #endif | #endif | ||||
| { | { | ||||
| @@ -1,227 +0,0 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2025, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| * POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #include "common.h" | |||||
| static float bfloat16tof32(bfloat16 f16) { /* Widen a bfloat16 to float: its bits are stored into the high-order 16-bit half of a zeroed float. Assumes bfloat16 is a 16-bit integer type -- TODO confirm against common.h. */ | |||||
| float result = 0; /* zero-initialised so the half not written below stays 0 */ | |||||
| unsigned short *q = (unsigned short *)(&result); /* view the float as two 16-bit halves */ | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| q[0] = f16; /* big-endian: the high-order half comes first */ | |||||
| #else | |||||
| q[1] = f16; /* otherwise (little-endian assumed): the high-order half comes second */ | |||||
| #endif | |||||
| return result; | |||||
| } | |||||
| static bfloat16 f32tobfloat16(float f32) { /* Narrow a float to bfloat16 by returning only its high-order 16-bit half; the low-order half is discarded, so the value is truncated rather than rounded. Assumes bfloat16 is a 16-bit integer type -- TODO confirm against common.h. */ | |||||
| unsigned short *q = (unsigned short *)(&f32); /* view the by-value copy as two 16-bit halves */ | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| return q[0]; /* big-endian: the high-order half comes first */ | |||||
| #else | |||||
| return q[1]; /* otherwise (little-endian assumed): the high-order half comes second */ | |||||
| #endif | |||||
| } | |||||
| #define BF16TOF32(x) (bfloat16tof32(x)) | |||||
| #define F32TOBF16(x) (f32tobfloat16(x)) | |||||
| int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba, | |||||
| IFLOAT *bb, FLOAT *C, BLASLONG ldc) { | |||||
| BLASLONG i, j, k; | |||||
| FLOAT *C0, *C1; // bfloat16 | |||||
| IFLOAT *ptrba, *ptrbb; | |||||
| float res0, res1, res2, res3; | |||||
| float load0, load1, load2, load3, load4, load5, load6, load7; | |||||
| float alpha_ = BF16TOF32(alpha); | |||||
| for (j = 0; j < bn / 2; j += 1) { | |||||
| C0 = C; | |||||
| C1 = C0 + ldc; | |||||
| ptrba = ba; | |||||
| for (i = 0; i < bm / 2; i += 1) { | |||||
| ptrbb = bb; | |||||
| res0 = 0; | |||||
| res1 = 0; | |||||
| res2 = 0; | |||||
| res3 = 0; | |||||
| for (k = 0; k < bk / 4; k += 1) { | |||||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||||
| load4 = BF16TOF32(ptrba[2 * 1 + 0]); | |||||
| load6 = BF16TOF32(ptrba[2 * 1 + 1]); | |||||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||||
| load3 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||||
| load5 = BF16TOF32(ptrbb[2 * 1 + 0]); | |||||
| load7 = BF16TOF32(ptrbb[2 * 1 + 1]); | |||||
| res0 = res0 + load0 * load1; | |||||
| res1 = res1 + load2 * load1; | |||||
| res2 = res2 + load0 * load3; | |||||
| res3 = res3 + load2 * load3; | |||||
| res0 = res0 + load4 * load5; | |||||
| res1 = res1 + load6 * load5; | |||||
| res2 = res2 + load4 * load7; | |||||
| res3 = res3 + load6 * load7; | |||||
| load0 = BF16TOF32(ptrba[2 * 2 + 0]); | |||||
| load2 = BF16TOF32(ptrba[2 * 2 + 1]); | |||||
| load4 = BF16TOF32(ptrba[2 * 3 + 0]); | |||||
| load6 = BF16TOF32(ptrba[2 * 3 + 1]); | |||||
| load1 = BF16TOF32(ptrbb[2 * 2 + 0]); | |||||
| load3 = BF16TOF32(ptrbb[2 * 2 + 1]); | |||||
| load5 = BF16TOF32(ptrbb[2 * 3 + 0]); | |||||
| load7 = BF16TOF32(ptrbb[2 * 3 + 1]); | |||||
| res0 = res0 + load0 * load1; | |||||
| res1 = res1 + load2 * load1; | |||||
| res2 = res2 + load0 * load3; | |||||
| res3 = res3 + load2 * load3; | |||||
| res0 = res0 + load4 * load5; | |||||
| res1 = res1 + load6 * load5; | |||||
| res2 = res2 + load4 * load7; | |||||
| res3 = res3 + load6 * load7; | |||||
| } | |||||
| for (k = 0; k < (bk & 3); k += 1) { | |||||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||||
| load3 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||||
| res0 = res0 + load0 * load1; | |||||
| res1 = res1 + load2 * load1; | |||||
| res2 = res2 + load0 * load3; | |||||
| res3 = res3 + load2 * load3; | |||||
| ptrba = ptrba + 2; | |||||
| ptrbb = ptrbb + 2; | |||||
| } | |||||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||||
| res1 = res1 * alpha_ + BF16TOF32(C0[1]); | |||||
| res2 = res2 * alpha_ + BF16TOF32(C1[0]); | |||||
| res3 = res3 * alpha_ + BF16TOF32(C1[1]); | |||||
| C0[0] = F32TOBF16(res0); | |||||
| C0[1] = F32TOBF16(res1); | |||||
| C1[0] = F32TOBF16(res2); | |||||
| C1[1] = F32TOBF16(res3); | |||||
| C0 = C0 + 2; | |||||
| C1 = C1 + 2; | |||||
| } | |||||
| for (i = 0; i < (bm & 1); i += 1) { | |||||
| ptrbb = bb; | |||||
| res0 = 0; | |||||
| res1 = 0; | |||||
| for (k = 0; k < bk; k += 1) { | |||||
| load0 = BF16TOF32(ptrba[0 + 0]); | |||||
| load1 = BF16TOF32(ptrbb[2 * 0 + 0]); | |||||
| load2 = BF16TOF32(ptrbb[2 * 0 + 1]); | |||||
| res0 = res0 + load0 * load1; | |||||
| res1 = res1 + load0 * load2; | |||||
| ptrba = ptrba + 1; | |||||
| ptrbb = ptrbb + 2; | |||||
| } | |||||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||||
| res1 = res1 * alpha_ + BF16TOF32(C1[0]); | |||||
| C0[0] = res0; | |||||
| C1[0] = res1; | |||||
| C0 = C0 + 1; | |||||
| C1 = C1 + 1; | |||||
| } | |||||
| k = (bk << 1); | |||||
| bb = bb + k; | |||||
| i = (ldc << 1); | |||||
| C = C + i; | |||||
| } | |||||
| for (j = 0; j < (bn & 1); j += 1) { | |||||
| C0 = C; | |||||
| ptrba = ba; | |||||
| for (i = 0; i < bm / 2; i += 1) { | |||||
| ptrbb = bb; | |||||
| res0 = 0; | |||||
| res1 = 0; | |||||
| for (k = 0; k < bk; k += 1) { | |||||
| load0 = BF16TOF32(ptrba[2 * 0 + 0]); | |||||
| load2 = BF16TOF32(ptrba[2 * 0 + 1]); | |||||
| load1 = BF16TOF32(ptrbb[0 + 0]); | |||||
| res0 = res0 + load0 * load1; | |||||
| res1 = res1 + load2 * load1; | |||||
| ptrba = ptrba + 2; | |||||
| ptrbb = ptrbb + 1; | |||||
| } | |||||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||||
| res1 = res1 * alpha_ + BF16TOF32(C0[1]); | |||||
| C0[0] = F32TOBF16(res0); | |||||
| C0[1] = F32TOBF16(res1); | |||||
| C0 = C0 + 2; | |||||
| } | |||||
| for (i = 0; i < (bm & 1); i += 1) { | |||||
| ptrbb = bb; | |||||
| res0 = 0; | |||||
| for (k = 0; k < bk; k += 1) { | |||||
| load0 = BF16TOF32(ptrba[0 + 0]); | |||||
| load1 = BF16TOF32(ptrbb[0 + 0]); | |||||
| res0 += load0 * load1; | |||||
| ptrba = ptrba + 1; | |||||
| ptrbb = ptrbb + 1; | |||||
| } | |||||
| res0 = res0 * alpha_ + BF16TOF32(C0[0]); | |||||
| C0[0] = F32TOBF16(res0); | |||||
| C0 = C0 + 1; | |||||
| } | |||||
| k = (bk << 0); | |||||
| bb = bb + k; | |||||
| C = C + ldc; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| @@ -71,11 +71,15 @@ f32tobfloat16(float f32) | |||||
| #define F32TOBF16(x) x | #define F32TOBF16(x) x | ||||
| #endif | #endif | ||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, | |||||
| #if defined(BFLOAT16_ONLY) | |||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, float beta, | |||||
| IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, | IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, | ||||
| FLOAT *c, BLASLONG ldc){ | FLOAT *c, BLASLONG ldc){ | ||||
| #else | |||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, | |||||
| IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, | |||||
| FLOAT *c, BLASLONG ldc){ | |||||
| #endif | |||||
| BLASLONG i, j; | BLASLONG i, j; | ||||
| BLASLONG chunk, remain; | BLASLONG chunk, remain; | ||||
| @@ -83,25 +87,24 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, | |||||
| c_offset = c; | c_offset = c; | ||||
| chunk = m >> 3; | chunk = m >> 3; | ||||
| remain = m & 7; | remain = m & 7; | ||||
| float beta = BF16TOF32(beta_in); | |||||
| if (beta == ZERO){ | if (beta == ZERO){ | ||||
| for(j=n; j>0; j--){ | for(j=n; j>0; j--){ | ||||
| c_offset1 = c_offset; | c_offset1 = c_offset; | ||||
| c_offset += ldc; | c_offset += ldc; | ||||
| for(i=chunk; i>0; i--){ | for(i=chunk; i>0; i--){ | ||||
| *(c_offset1 + 0) = ZERO; | |||||
| *(c_offset1 + 1) = ZERO; | |||||
| *(c_offset1 + 2) = ZERO; | |||||
| *(c_offset1 + 3) = ZERO; | |||||
| *(c_offset1 + 4) = ZERO; | |||||
| *(c_offset1 + 5) = ZERO; | |||||
| *(c_offset1 + 6) = ZERO; | |||||
| *(c_offset1 + 7) = ZERO; | |||||
| *(c_offset1 + 0) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 1) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 2) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 3) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 4) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 5) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 6) = F32TOBF16(ZERO); | |||||
| *(c_offset1 + 7) = F32TOBF16(ZERO); | |||||
| c_offset1 += 8; | c_offset1 += 8; | ||||
| } | } | ||||
| for(i=remain; i>0; i--){ | for(i=remain; i>0; i--){ | ||||
| *c_offset1 = ZERO; | |||||
| *c_offset1 = F32TOBF16(ZERO); | |||||
| c_offset1 ++; | c_offset1 ++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,32 @@ | |||||
| /*************************************************************************** | |||||
| Copyright (c) 2025 The OpenBLAS Project | |||||
| All rights reserved. | |||||
| Redistribution and use in source and binary forms, with or without | |||||
| modification, are permitted provided that the following conditions are | |||||
| met: | |||||
| 1. Redistributions of source code must retain the above copyright | |||||
| notice, this list of conditions and the following disclaimer. | |||||
| 2. Redistributions in binary form must reproduce the above copyright | |||||
| notice, this list of conditions and the following disclaimer in | |||||
| the documentation and/or other materials provided with the | |||||
| distribution. | |||||
| 3. Neither the name of the OpenBLAS project nor the names of | |||||
| its contributors may be used to endorse or promote products | |||||
| derived from this software without specific prior written permission. | |||||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE | |||||
| GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |||||
| HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
| LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF | |||||
| THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| *****************************************************************************/ | |||||
| #include "common.h" | #include "common.h" | ||||
| #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) | |||||
| #if (defined(BFLOAT16) || defined(BFLOAT16_ONLY))&& defined(BFLOAT16CONVERSION) | |||||
| static float | static float | ||||
| bfloat16tof32 (bfloat16 f16) | bfloat16tof32 (bfloat16 f16) | ||||
| { | { | ||||
| @@ -12,12 +39,29 @@ bfloat16tof32 (bfloat16 f16) | |||||
| #endif | #endif | ||||
| return result; | return result; | ||||
| } | } | ||||
| static bfloat16 f32tobfloat16(float f32) { | |||||
| unsigned short *q = (unsigned short *)(&f32); | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| return q[0]; | |||||
| #else | |||||
| return q[1]; | |||||
| #endif | |||||
| } | |||||
| #define BF16TOF32(x) (bfloat16tof32(x)) | #define BF16TOF32(x) (bfloat16tof32(x)) | ||||
| #define F32TOBF16(x) (f32tobfloat16(x)) | |||||
| #else | #else | ||||
| #define BF16TOF32(x) x | #define BF16TOF32(x) x | ||||
| #define F32TOBF16(x) x | |||||
| #endif | #endif | ||||
| #ifdef BFLOAT16_ONLY | |||||
| int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk, float alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc | |||||
| #else | |||||
| int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc | int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc | ||||
| #ifdef TRMMKERNEL | |||||
| #endif | |||||
| #ifdef TRMMKERNEL | |||||
| ,BLASLONG offset | ,BLASLONG offset | ||||
| #endif | #endif | ||||
| ) | ) | ||||
| @@ -90,13 +134,17 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||||
| ptrbb = ptrbb+2; | ptrbb = ptrbb+2; | ||||
| } | } | ||||
| res0 = res0*alpha; | res0 = res0*alpha; | ||||
| C0[0] = C0[0]+res0; | |||||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||||
| res1 = res1*alpha; | res1 = res1*alpha; | ||||
| C0[1] = C0[1]+res1; | |||||
| C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); | |||||
| res2 = res2*alpha; | res2 = res2*alpha; | ||||
| C1[0] = C1[0]+res2; | |||||
| C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); | |||||
| res3 = res3*alpha; | res3 = res3*alpha; | ||||
| C1[1] = C1[1]+res3; | |||||
| C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[1])); | |||||
| C0 = C0+2; | C0 = C0+2; | ||||
| C1 = C1+2; | C1 = C1+2; | ||||
| } | } | ||||
| @@ -116,9 +164,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||||
| ptrbb = ptrbb+2; | ptrbb = ptrbb+2; | ||||
| } | } | ||||
| res0 = res0*alpha; | res0 = res0*alpha; | ||||
| C0[0] = C0[0]+res0; | |||||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||||
| res1 = res1*alpha; | res1 = res1*alpha; | ||||
| C1[0] = C1[0]+res1; | |||||
| C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); | |||||
| C0 = C0+1; | C0 = C0+1; | ||||
| C1 = C1+1; | C1 = C1+1; | ||||
| } | } | ||||
| @@ -147,9 +197,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||||
| ptrbb = ptrbb+1; | ptrbb = ptrbb+1; | ||||
| } | } | ||||
| res0 = res0*alpha; | res0 = res0*alpha; | ||||
| C0[0] = C0[0]+res0; | |||||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||||
| res1 = res1*alpha; | res1 = res1*alpha; | ||||
| C0[1] = C0[1]+res1; | |||||
| C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); | |||||
| C0 = C0+2; | C0 = C0+2; | ||||
| } | } | ||||
| for (i=0; i<(bm&1); i+=1) | for (i=0; i<(bm&1); i+=1) | ||||
| @@ -165,7 +217,8 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, | |||||
| ptrbb = ptrbb+1; | ptrbb = ptrbb+1; | ||||
| } | } | ||||
| res0 = res0*alpha; | res0 = res0*alpha; | ||||
| C0[0] = C0[0]+res0; | |||||
| C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); | |||||
| //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); | |||||
| C0 = C0+1; | C0 = C0+1; | ||||
| } | } | ||||
| k = (bk<<0); | k = (bk<<0); | ||||
| @@ -1,3 +1,31 @@ | |||||
| ############################################################################### | |||||
| # Copyright (c) 2025, The OpenBLAS Project | |||||
| # All rights reserved. | |||||
| # Redistribution and use in source and binary forms, with or without | |||||
| # modification, are permitted provided that the following conditions are | |||||
| # met: | |||||
| # 1. Redistributions of source code must retain the above copyright | |||||
| # notice, this list of conditions and the following disclaimer. | |||||
| # 2. Redistributions in binary form must reproduce the above copyright | |||||
| # notice, this list of conditions and the following disclaimer in | |||||
| # the documentation and/or other materials provided with the | |||||
| # distribution. | |||||
| # 3. Neither the name of the OpenBLAS project nor the names of | |||||
| # its contributors may be used to endorse or promote products | |||||
| # derived from this software without specific prior written permission. | |||||
| # | |||||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| # POSSIBILITY OF SUCH DAMAGE. | |||||
| ############################################################################### | |||||
| TOPDIR = .. | TOPDIR = .. | ||||
| include ../Makefile.system | include ../Makefile.system | ||||
| ifeq ($(F_COMPILER),GFORTRAN) | ifeq ($(F_COMPILER),GFORTRAN) | ||||
| @@ -164,6 +192,9 @@ endif | |||||
| endif | endif | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| BF3= test_bgemm | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| B3= test_sbgemm | B3= test_sbgemm | ||||
| endif | endif | ||||
| @@ -192,11 +223,15 @@ endif | |||||
| ifeq ($(SUPPORT_GEMM3M),1) | ifeq ($(SUPPORT_GEMM3M),1) | ||||
| level3: $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m | level3: $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m | ||||
| else | else | ||||
| level3: $(B3) $(S3) $(D3) $(C3) $(Z3) | |||||
| level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) | |||||
| endif | endif | ||||
| ifneq ($(CROSS), 1) | ifneq ($(CROSS), 1) | ||||
| rm -f ?BLAT3.SUMM | rm -f ?BLAT3.SUMM | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM | |||||
| @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM | OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM | ||||
| @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 | @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 | ||||
| @@ -366,6 +401,11 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) | |||||
| $(FC) $(FLDFLAGS) -o zblat3 zblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | $(FC) $(FLDFLAGS) -o zblat3 zblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | ||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16_ONLY),1) | |||||
| test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) | |||||
| $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | |||||
| endif | |||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) | test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) | ||||
| $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) | ||||
| @@ -387,7 +427,7 @@ clean: | |||||
| @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ | @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ | ||||
| sblat1 dblat1 cblat1 zblat1 \ | sblat1 dblat1 cblat1 zblat1 \ | ||||
| sblat2 dblat2 cblat2 zblat2 \ | sblat2 dblat2 cblat2 zblat2 \ | ||||
| test_sbgemm sblat3 dblat3 cblat3 zblat3 \ | |||||
| test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ | |||||
| sblat1p dblat1p cblat1p zblat1p \ | sblat1p dblat1p cblat1p zblat1p \ | ||||
| sblat2p dblat2p cblat2p zblat2p \ | sblat2p dblat2p cblat2p zblat2p \ | ||||
| sblat3p dblat3p cblat3p zblat3p \ | sblat3p dblat3p cblat3p zblat3p \ | ||||
| @@ -57,10 +57,8 @@ int main(int argc, char *argv[]) { | |||||
| char transA = 'N', transB = 'N'; | char transA = 'N', transB = 'N'; | ||||
| float alpha = 1.0, beta = 0.0; | float alpha = 1.0, beta = 0.0; | ||||
| bfloat16 alpha_bf16 = convert_to_bf16(alpha), | |||||
| beta_bf16 = convert_to_bf16(beta); | |||||
| for (x = 1; x <= BGEMM_LARGEST; x++) { | |||||
| for (x = 1; x <= loop; x++) { | |||||
| if ((x > 100) && (x != BGEMM_LARGEST)) | if ((x > 100) && (x != BGEMM_LARGEST)) | ||||
| continue; | continue; | ||||
| m = k = n = x; | m = k = n = x; | ||||
| @@ -79,14 +77,14 @@ int main(int argc, char *argv[]) { | |||||
| for (int i = 0; i < m; i++) { | for (int i = 0; i < m; i++) { | ||||
| for (int j = 0; j < k; j++) { | for (int j = 0; j < k; j++) { | ||||
| AA[i * k + j] = (i * k + j + 1) % 100; | |||||
| AA[i * k + j] = (i * k + j + 1) % 5; | |||||
| A[i * k + j] = AA[i * k + j]; | A[i * k + j] = AA[i * k + j]; | ||||
| } | } | ||||
| } | } | ||||
| for (int i = 0; i < n; i++) { | for (int i = 0; i < n; i++) { | ||||
| for (int j = 0; j < k; j++) { | for (int j = 0; j < k; j++) { | ||||
| BB[i * k + j] = (i * k + j + 1) % 100; | |||||
| BB[i * k + j] = (i * k + j + 1) % 5; | |||||
| B[i * k + j] = BB[i * k + j]; | B[i * k + j] = BB[i * k + j]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -102,23 +100,73 @@ int main(int argc, char *argv[]) { | |||||
| } else { | } else { | ||||
| transB = 'T'; | transB = 'T'; | ||||
| } | } | ||||
| // printf("******** x = %d, y = %d********\n", x, y); | |||||
| // printf("Matrix AA (m x k):\n"); | |||||
| // for (int i = 0; i < m; i++) { | |||||
| // for (int j = 0; j < k; j++) { | |||||
| // printf("%.2f ", (float)AA[i * k + j]); // or %4.1f if float | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| // printf("Matrix A (copy of AA):\n"); | |||||
| // for (int i = 0; i < m; i++) { | |||||
| // for (int j = 0; j < k; j++) { | |||||
| // printf("%.2f ", A[i * k + j]); | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| // printf("Matrix BB (n x k):\n"); | |||||
| // for (int i = 0; i < n; i++) { | |||||
| // for (int j = 0; j < k; j++) { | |||||
| // printf("%.2f ", (float)BB[i * k + j]); | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| // printf("Matrix B (copy of BB):\n"); | |||||
| // for (int i = 0; i < n; i++) { | |||||
| // for (int j = 0; j < k; j++) { | |||||
| // printf("%.2f ", B[i * k + j]); | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| memset(C, 0, m * n * sizeof(FLOAT)); | memset(C, 0, m * n * sizeof(FLOAT)); | ||||
| memset(CC, 0, m * n * sizeof(bfloat16)); | memset(CC, 0, m * n * sizeof(bfloat16)); | ||||
| SGEMM(&transA, &transB, &m, &n, &k, &alpha_bf16, A, &m, B, &k, &beta_bf16, | |||||
| SGEMM(&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, | |||||
| C, &m); | C, &m); | ||||
| BGEMM(&transA, &transB, &m, &n, &k, &alpha, (bfloat16 *)AA, &m, | BGEMM(&transA, &transB, &m, &n, &k, &alpha, (bfloat16 *)AA, &m, | ||||
| (bfloat16 *)BB, &k, &beta, (bfloat16 *)CC, &m); | (bfloat16 *)BB, &k, &beta, (bfloat16 *)CC, &m); | ||||
| // printf("Matrix CC (n x m):\n"); | |||||
| // for (int i = 0; i < n; i++) { | |||||
| // for (int j = 0; j < m; j++) { | |||||
| // printf("%.2f ", (float)CC[i * m + j]); | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| // printf("Matrix C :\n"); | |||||
| // for (int i = 0; i < n; i++) { | |||||
| // for (int j = 0; j < k; j++) { | |||||
| // printf("%.2f ", C[i * k + j]); | |||||
| // } | |||||
| // printf("\n"); | |||||
| // } | |||||
| for (i = 0; i < n; i++) { | for (i = 0; i < n; i++) { | ||||
| for (j = 0; j < m; j++) { | for (j = 0; j < m; j++) { | ||||
| for (l = 0; l < k; l++) { | |||||
| if (fabs(CC[i * m + j] - C[i * m + j]) > 1.0) { | |||||
| ret++; | |||||
| if (fabs((float)CC[i * m + j] - C[i * m + j]) > 1.0) { | |||||
| ret ++; | |||||
| } | } | ||||
| } | |||||
| } | } | ||||
| } | } | ||||
| printf("x = %d, err = %d\n", x, ret); | |||||
| ret = 0; | |||||
| } | } | ||||
| free(A); | free(A); | ||||