diff --git a/benchmark/Makefile b/benchmark/Makefile index 9316921e1..cdf87c0ab 100644 --- a/benchmark/Makefile +++ b/benchmark/Makefile @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + TOPDIR = .. 
include $(TOPDIR)/Makefile.system @@ -56,7 +84,7 @@ GOTO_LAPACK_TARGETS= endif ifeq ($(BUILD_BFLOAT16),1) -GOTO_BFLOAT_TARGETS=sbgemm.goto +GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto else GOTO_BFLOAT_TARGETS= endif @@ -635,6 +663,8 @@ zcholesky.essl : zcholesky.$(SUFFIX) ##################################### Sgemm #################################################### ifeq ($(BUILD_BFLOAT16),1) +bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME) + $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME) $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm endif @@ -2970,6 +3000,8 @@ zcholesky.$(SUFFIX) : cholesky.c $(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^ ifeq ($(BUILD_BFLOAT16),1) +bgemm.$(SUFFIX) : gemm.c + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^ sbgemm.$(SUFFIX) : gemm.c $(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^ endif diff --git a/benchmark/gemm.c b/benchmark/gemm.c index a138bfe1e..704e33225 100644 --- a/benchmark/gemm.c +++ b/benchmark/gemm.c @@ -33,6 +33,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#ifdef DOUBLE #define GEMM BLASFUNC(dgemm) +#elif defined(BFLOAT16) && defined(BGEMM) +#define GEMM BLASFUNC(bgemm) #elif defined(BFLOAT16) #define GEMM BLASFUNC(sbgemm) #undef IFLOAT @@ -60,8 +62,18 @@ int main(int argc, char *argv[]){ IFLOAT *a, *b; FLOAT *c; +#ifdef BGEMM + blasint one=1; + blasint two=2; + float alpha_in[] = {1.0, 0.0}; + float beta_in[] = {0.0, 0.0}; + FLOAT alpha[2], beta[2]; + sbstobf16_(&two, alpha_in, &one, alpha, &one); + sbstobf16_(&two, beta_in, &one, beta, &one); +#else FLOAT alpha[] = {1.0, 0.0}; FLOAT beta [] = {0.0, 0.0}; +#endif char transa = 'N'; char transb = 'N'; blasint m, n, k, i, j, lda, ldb, ldc; diff --git a/common_b.h b/common_b.h index e03f6800d..4d77ec4fa 100644 --- a/common_b.h +++ b/common_b.h @@ -30,10 +30,16 @@ #define COMMON_B_H #ifndef DYNAMIC_ARCH -#define BGEMM_ONCOPY bgemm_oncopy -#define BGEMM_OTCOPY bgemm_otcopy -#define BGEMM_INCOPY bgemm_incopy -#define BGEMM_ITCOPY bgemm_itcopy +#define BGEMM_ONCOPY bgemm_oncopy +#define BGEMM_OTCOPY bgemm_otcopy + +#if BGEMM_DEFAULT_UNROLL_M == BGEMM_DEFAULT_UNROLL_N +#define BGEMM_INCOPY bgemm_oncopy +#define BGEMM_ITCOPY bgemm_otcopy +#else +#define BGEMM_INCOPY bgemm_incopy +#define BGEMM_ITCOPY bgemm_itcopy +#endif #define BGEMM_BETA bgemm_beta #define BGEMM_KERNEL bgemm_kernel diff --git a/interface/Makefile b/interface/Makefile index e14796cbb..3af12748f 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. 
Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + TOPDIR = .. 
include $(TOPDIR)/Makefile.system @@ -526,7 +554,7 @@ ifneq ($(BUILD_COMPLEX16),1) ZBLASOBJS= endif -FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS) +FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS) ifeq ($(EXPRECISION), 1) FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 06f18e6be..f45ccc42d 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -674,6 +674,10 @@ ZBLASOBJS += \ endif ifeq ($(BUILD_BFLOAT16), 1) +BGEMMINCOPYOBJ_P = $(BGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +BGEMMITCOPYOBJ_P = $(BGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +BGEMMONCOPYOBJ_P = $(BGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +BGEMMOTCOPYOBJ_P = $(BGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SBGEMMINCOPYOBJ_P = $(SBGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SBGEMMITCOPYOBJ_P = $(SBGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) @@ -2998,6 +3002,20 @@ $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) ifeq ($(BUILD_BFLOAT16), 1) +$(BGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMONCOPY) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +$(BGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMOTCOPY) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N)) +$(BGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMINCOPY) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +$(BGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMITCOPY) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(SBGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMONCOPY) $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ @@ -3010,7 +3028,6 @@ $(SBGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMINCOPY) $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY) $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o 
$@ - endif endif @@ -3137,6 +3154,8 @@ endif ifeq ($(BUILD_BFLOAT16), 1) +$(KDIR)bgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) $(BGEMMDEPEND) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1 index 3e622bcbf..8bc0f35e5 100644 --- a/kernel/arm64/KERNEL.NEOVERSEV1 +++ b/kernel/arm64/KERNEL.NEOVERSEV1 @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
 include $(KERNELDIR)/KERNEL.ARMV8SVE
 
 SGEMVNKERNEL = gemv_n_sve_v1x3.c
@@ -5,6 +33,19 @@ DGEMVNKERNEL = gemv_n_sve_v1x3.c
 SGEMVTKERNEL = gemv_t_sve_v1x3.c
 DGEMVTKERNEL = gemv_t_sve_v1x3.c
 ifeq ($(BUILD_BFLOAT16), 1)
+# BGEMM reuses the SBGEMM packing (copy) sources; only the kernel and the
+# beta routine are BGEMM-specific. The inner (M-side) copy objects are only
+# declared when the M and N unroll factors differ, and their source file is
+# selected by the BGEMM M unroll factor, matching the guard on the ifneq.
+BGEMM_BETA = bgemm_beta_neon.c
+BGEMMKERNEL = bgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversev1.c
+ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N))
+BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversev1.c
+BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversev1.c
+BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX)
+BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX)
+endif
+BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversev1.c
+BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
+BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
+BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)
+
 SBGEMM_BETA = sbgemm_beta_neoversev1.c
 SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
 ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))
@@ -21,4 +62,4 @@ SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
 SBGEMVNKERNEL = sbgemv_n_neon.c
 SBGEMVTKERNEL = sbgemv_t_bfdot.c
 
-endif
\ No newline at end of file
+endif
diff --git a/kernel/arm64/bgemm_beta_neon.c b/kernel/arm64/bgemm_beta_neon.c
new file mode 100644
index 000000000..603377f8f
--- /dev/null
+++ b/kernel/arm64/bgemm_beta_neon.c
@@ -0,0 +1,107 @@
+/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/
+
+#include "common.h"
+
+#include <arm_neon.h>
+
+/*
+ * BGEMM beta kernel (NEON): scales the bf16 matrix C in place, C := beta * C.
+ *
+ * m        number of contiguous elements processed in each column of C
+ * n        number of columns of C
+ * beta_in  scale factor; the FLOAT argument carries a bf16 bit pattern, which
+ *          is copied out bit-for-bit and widened to float before use
+ * c        first element of C, accessed as bfloat16_t
+ * ldc      distance between consecutive columns of C, in elements
+ * dummy1..dummy5 are not read by this routine.
+ *
+ * When beta is exactly zero, C is overwritten with zeros without being read;
+ * otherwise every element is widened to float, multiplied by beta and
+ * converted back to bf16. Always returns 0.
+ */
+int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, IFLOAT *dummy2,
+          BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c,
+          BLASLONG ldc) {
+  bfloat16_t *ptr_c = (bfloat16_t *)c;
+
+  /* Copy the bf16 bits out of beta_in; a value conversion would be wrong. */
+  bfloat16_t beta_bf16;
+  memcpy(&beta_bf16, &beta_in, sizeof(bfloat16_t));
+  const float beta = vcvtah_f32_bf16(beta_bf16);
+
+  const BLASLONG chunk = m >> 3;  /* full 8-element vectors per column */
+  const BLASLONG remain = m & 7;  /* elements left for the scalar tail */
+
+  if (beta == 0.0f) {
+    const bfloat16_t zero_bf16 = vcvth_bf16_f32(0.0f);
+    const bfloat16x8_t zeros = vdupq_n_bf16(zero_bf16);
+
+    for (BLASLONG j = 0; j < n; j++) {
+      bfloat16_t *ptr_c0 = ptr_c;
+      ptr_c += ldc;
+
+      for (BLASLONG i = 0; i < chunk; i++) {
+        vst1q_bf16(ptr_c0, zeros);
+        ptr_c0 += 8;
+      }
+
+      for (BLASLONG i = 0; i < remain; i++) {
+        ptr_c0[0] = zero_bf16;
+        ptr_c0++;
+      }
+    }
+  } else {
+    const float32x4_t beta_neon = vdupq_n_f32(beta);
+
+    for (BLASLONG j = 0; j < n; j++) {
+      bfloat16_t *ptr_c0 = ptr_c;
+      ptr_c += ldc;
+
+      for (BLASLONG i = 0; i < chunk; i++) {
+        bfloat16x8_t x0 = vld1q_bf16(ptr_c0);
+
+        /* Widen the two halves to float, scale, then narrow back. */
+        float32x4_t y0 = vcvtq_low_f32_bf16(x0);
+        float32x4_t y1 = vcvtq_high_f32_bf16(x0);
+
+        y0 = vmulq_f32(y0, beta_neon);
+        y1 = vmulq_f32(y1, beta_neon);
+
+        bfloat16x8_t z0 = vcvtq_low_bf16_f32(y0);
+        z0 = vcvtq_high_bf16_f32(z0, y1);
+
+        vst1q_bf16(ptr_c0, z0);
+
+        ptr_c0 += 8;
+      }
+
+      for (BLASLONG i = 0; i < remain; i++) {
+        float x = vcvtah_f32_bf16(ptr_c0[0]);
+        bfloat16_t z = vcvth_bf16_f32(x * beta);
+
+        ptr_c0[0] = z;
+        ptr_c0++;
+      }
+    }
+  }
+  return 0;
+}
diff --git a/kernel/arm64/bgemm_kernel_4x4_neoversev1.c b/kernel/arm64/bgemm_kernel_4x4_neoversev1.c
new file mode 100644
index 000000000..7af31bb2c
--- /dev/null
+++ b/kernel/arm64/bgemm_kernel_4x4_neoversev1.c
@@ -0,0 +1,50 @@
+/***************************************************************************
+ * Copyright (c) 2025, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/
+
+#include <arm_neon.h>
+#include <arm_sve.h>
+
+#include "common.h"
+
+/*
+ * The implementation file is included twice to generate two functions from
+ * one body: with ALPHA_ONE defined it emits bgemm_kernel_neoversev1_alpha_one,
+ * without it bgemm_kernel_neoversev1_alpha. UPDATE_C is the only macro whose
+ * body differs between the two passes, so it must be undefined in between;
+ * the other helper macros are redefined with identical text.
+ */
+#define ALPHA_ONE
+#include "bgemm_kernel_4x4_neoversev1_impl.c"
+#undef ALPHA_ONE
+#undef UPDATE_C
+#include "bgemm_kernel_4x4_neoversev1_impl.c"
+
+/*
+ * BGEMM micro-kernel entry point: forwards all arguments unchanged to one of
+ * the two generated variants and returns its result.
+ *
+ * alpha is a FLOAT that carries a bf16 bit pattern; the bits are copied out
+ * and widened to float only to decide which variant to call. When alpha is
+ * exactly 1 the variant built with ALPHA_ONE is used, otherwise the general
+ * variant, which multiplies by alpha.
+ */
+int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
+          FLOAT *C, BLASLONG ldc) {
+  bfloat16_t alpha_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+
+  if (alpha_f32 == 1.0f)
+    return bgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc);
+
+  return bgemm_kernel_neoversev1_alpha(m, n, k, alpha, A, B, C, ldc);
+}
diff --git a/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c b/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c
new file mode 100644
index 000000000..2477da9c0
--- /dev/null
+++ b/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c
@@ -0,0 +1,429 @@
+/***************************************************************************
+ * Copyright (c) 2025, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ * 3. Neither the name of the OpenBLAS project nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ * *****************************************************************************/
+
+#include <arm_neon.h>
+#include <arm_sve.h>
+
+#include "common.h"
+
+/*
+ * Helper macros for the kernel body below. They expand to expressions or to
+ * do { } while (0) blocks WITHOUT a trailing semicolon; the call site supplies
+ * it, so each use behaves as exactly one statement.
+ */
+
+/* Zero the f32 accumulator mc<M><N>. */
+#define INIT_C(M, N) mc##M##N = svdup_f32(0)
+
+/* mc<M><N> += ma<M> x mb<N> using the bf16 matrix multiply-accumulate. */
+#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N)
+
+/* Zero all four accumulators used by the 4x4 tile. */
+#define INIT_C_4x4                                                             \
+  do {                                                                         \
+    INIT_C(0, 0);                                                              \
+    INIT_C(0, 1);                                                              \
+    INIT_C(1, 0);                                                              \
+    INIT_C(1, 1);                                                              \
+  } while (0)
+
+/*
+ * UPDATE_C: accumulate the f32 results in SRC32 into the bf16 values at PTR.
+ *   ALPHA_ONE:  C = bf16(SRC32 + C)
+ *   otherwise:  C = bf16(svalpha * SRC32 + C)
+ * PG32 selects the active 32-bit lanes, PG16 the same number of 16-bit
+ * elements for the final store. TMP32 / TMP16 are scratch vectors.
+ *
+ * C is widened bf16 -> f32 by loading each 16-bit element into its own 32-bit
+ * lane and shifting it into the upper half, because a bf16 value is the high
+ * 16 bits of the corresponding f32 encoding. The load is governed by PG32: it
+ * operates on 32-bit lanes, and a predicate built for 16-bit elements would
+ * enable only every other intended lane.
+ * After svcvt the bf16 results sit in the even 16-bit elements; svuzp1 packs
+ * them contiguously for the store.
+ */
+#ifdef ALPHA_ONE
+#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32)                         \
+  do {                                                                         \
+    TMP32 = svreinterpret_f32_u32(svlsl_n_u32_x(                               \
+        (PG32), svld1uh_u32((PG32), (const uint16_t *)(PTR)), 16));            \
+    TMP32 = svadd_z((PG32), SRC32, TMP32);                                     \
+    TMP16 = svcvt_bf16_f32_z((PG32), TMP32);                                   \
+    TMP16 = svuzp1_bf16(TMP16, TMP16);                                         \
+    svst1_bf16((PG16), (PTR), TMP16);                                          \
+  } while (0)
+#else
+#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32)                         \
+  do {                                                                         \
+    TMP32 = svreinterpret_f32_u32(svlsl_n_u32_x(                               \
+        (PG32), svld1uh_u32((PG32), (const uint16_t *)(PTR)), 16));            \
+    TMP32 = svmad_z((PG32), svalpha, SRC32, TMP32);                            \
+    TMP16 = svcvt_bf16_f32_z((PG32), TMP32);                                   \
+    TMP16 = svuzp1_bf16(TMP16, TMP16);                                         \
+    svst1_bf16((PG16), (PTR), TMP16);                                          \
+  } while (0)
+#endif
+
+/*
+ * Take the even-indexed f32 lanes of mc0 followed by those of mc1 (svuzp1),
+ * then pack the lanes selected by PG to the front of vc (svcompact).
+ */
+#define ZIP_EVEN_ELEMENTS(PG, mc0, mc1, tmp, vc)                               \
+  do {                                                                         \
+    (tmp) = svuzp1_f32((mc0), (mc1));                                          \
+    (vc) = svcompact_f32((PG), (tmp));                                         \
+  } while (0)
+
+/*
+ * Take the odd-indexed f32 lanes of mc0 followed by those of mc1 (svuzp2),
+ * then pack the lanes selected by PG to the front of vc (svcompact).
+ */
+#define ZIP_ODD_ELEMENTS(PG, mc0, mc1, tmp, vc) \
+  do { \
+    (tmp) = svuzp2_f32((mc0), (mc1)); \
+    (vc) = svcompact_f32((PG), (tmp)); \
+  } while (0)
+
+/*
+ * Add accumulator mc<M><N> to a copy of itself rotated by four f32 lanes, so
+ * lane i ends up holding lane i plus the lane four positions above it.
+ * NOTE(review): presumably assumes 8 f32 lanes per vector (256-bit SVE), in
+ * which case this folds the upper half into the lower half - confirm.
+ */
+#define ACCUMULATE_LAST4_TO_FIRST4(M, N, TMP) \
+  do { \
+    TMP = svext_f32(mc##M##N, mc##M##N, 4); \
+    mc##M##N = svadd_f32_z(svptrue_b32(), mc##M##N, (TMP)); \
+  } while (0)
+
+/*
+ * BGEMM micro-kernel body. Depending on ALPHA_ONE this file defines
+ * bgemm_kernel_neoversev1_alpha_one or bgemm_kernel_neoversev1_alpha.
+ *
+ * m, n   extent of the C block to update
+ * k      inner dimension; the loops run over pad_k, k rounded up to a
+ *        multiple of 8
+ * alpha  FLOAT carrying a bf16 bit pattern; read only when ALPHA_ONE is not
+ *        defined, where it is widened to float and broadcast into svalpha
+ * A, B   input operands, read as bfloat16_t. A advances by 4 * pad_k (or
+ *        2 * pad_k) elements per 4-row (or 2-row) step, B by 4 * pad_k (or
+ *        2 * pad_k) per 4-column (or 2-column) step.
+ *        NOTE(review): presumably pre-packed, zero-padded panels produced by
+ *        the matching copy routines - confirm against the packing code.
+ * C      output, accessed as bfloat16_t and updated through UPDATE_C
+ * ldc    distance between consecutive columns of C, in elements
+ *
+ * Columns are handled in groups of 4, then 2, then 1; within each group rows
+ * are handled in groups of 4, then 2, then 1. Always returns 0.
+ */
+#ifdef ALPHA_ONE
+int bgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k,
+                                      FLOAT alpha, IFLOAT *A, IFLOAT *B,
+                                      FLOAT *C, BLASLONG ldc)
+#else
+int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k,
+                                  FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C,
+                                  BLASLONG ldc)
+#endif
+{
+  /* k rounded up to the next multiple of 8. */
+  BLASLONG pad_k = (k + 7) & ~7;
+  svbfloat16_t ma0, ma1, mb0, mb1;
+  svfloat32_t mc00, mc01, mc10, mc11, vc0, vc1, vc2, vc3;
+  svfloat32_t tmp;
+#ifndef ALPHA_ONE
+  /* Copy the bf16 bits out of alpha, widen to float and broadcast. */
+  bfloat16_t alpha_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
+#endif
+
+  svbool_t pg16_all = svptrue_b16();
+
+  /* Predicates selecting the first 1, 2 or 4 32-bit lanes. */
+  svbool_t pg32_first_1 = svwhilelt_b32(0, 1);
+  svbool_t pg32_first_2 = svwhilelt_b32(0, 2);
+  svbool_t pg32_first_4 = svwhilelt_b32(0, 4);
+
+  /* Predicates selecting the first 1, 2 or 4 16-bit elements. */
+  svbool_t pg16_first_1 = svwhilelt_b16(0, 1);
+  svbool_t pg16_first_2 = svwhilelt_b16(0, 2);
+  svbool_t pg16_first_4 = svwhilelt_b16(0, 4);
+
+  /* In every group of four 32-bit lanes, select the first two. */
+  svbool_t pg32_select_first_2_per_quadword = svdupq_b32(1, 1, 0, 0);
+
+  bfloat16_t *ptr_a = (bfloat16_t *)A;
+  bfloat16_t *ptr_b = (bfloat16_t *)B;
+  bfloat16_t *ptr_c = (bfloat16_t *)C;
+
+  bfloat16_t *ptr_a0;
+  bfloat16_t *ptr_b0;
+  bfloat16_t *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
+
+  svfloat32_t tmp32;
+  svbfloat16_t tmp16;
+
+  /* ---- groups of 4 columns ---- */
+  for (BLASLONG j = 0; j < n / 4; j++) {
+    ptr_c0 = ptr_c;
+    ptr_c1 = ptr_c0 + ldc;
+    ptr_c2 = ptr_c1 + ldc;
+    ptr_c3 = ptr_c2 + ldc;
+    ptr_c += 4 * ldc;
+    /* Every column group walks A again from the start. */
+    ptr_a = (bfloat16_t *)A;
+
+    /* 4 rows x 4 columns */
+    for (BLASLONG i = 0; i < m / 4; i++) {
+      ptr_a0 = ptr_a;
+      ptr_a += 4 * pad_k;
+
+      ptr_b0 = ptr_b;
+
+      INIT_C_4x4;
+
+      /* Each step consumes 32 elements of A and 32 of B. */
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);
+
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+        mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);
+
+        MATMUL(0, 0);
+        MATMUL(0, 1);
+        MATMUL(1, 0);
+        MATMUL(1, 1);
+
+        ptr_a0 += 32;
+        ptr_b0 += 32;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(1, 1, tmp);
+
+      /* Regroup accumulator lanes into one 4-element vector per C column:
+       * even lanes of mc00/mc10 feed ptr_c0, odd lanes feed ptr_c1. */
+      ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);
+      ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc1);
+
+      /* Likewise mc01/mc11 feed ptr_c2 and ptr_c3. */
+      ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc2);
+      ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc3);
+
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc1);
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, tmp32, tmp16, vc2);
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, tmp32, tmp16, vc3);
+
+      ptr_c0 += 4;
+      ptr_c1 += 4;
+      ptr_c2 += 4;
+      ptr_c3 += 4;
+    }
+
+    /* 2 rows x 4 columns */
+    if (m & 2) {
+      ptr_a0 = ptr_a;
+      ptr_a += 2 * pad_k;
+
+      ptr_b0 = ptr_b;
+      INIT_C(0, 0);
+      INIT_C(0, 1);
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+        mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);
+
+        MATMUL(0, 0);
+        MATMUL(0, 1);
+
+        ptr_a0 += 16;
+        ptr_b0 += 32;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);
+
+      /* Even lanes go to the first column of each pair, odd to the second. */
+      vc0 = svuzp1(mc00, mc00);
+      vc1 = svuzp2(mc00, mc00);
+      vc2 = svuzp1(mc01, mc01);
+      vc3 = svuzp2(mc01, mc01);
+
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1);
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, tmp32, tmp16, vc2);
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, tmp32, tmp16, vc3);
+
+      ptr_c0 += 2;
+      ptr_c1 += 2;
+      ptr_c2 += 2;
+      ptr_c3 += 2;
+    }
+
+    /* 1 row x 4 columns. A is still read 16 elements per step.
+     * NOTE(review): presumably the packed single-row tail is padded to the
+     * same width as a 2-row panel - confirm against the copy routine. */
+    if (m & 1) {
+      ptr_a0 = ptr_a;
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+      INIT_C(0, 1);
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+        mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);
+
+        MATMUL(0, 0);
+        MATMUL(0, 1);
+
+        ptr_a0 += 16;
+        ptr_b0 += 32;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);
+
+      // svcompact might be a more straightforward way to do this.
+      // Lane 0 of mc00 / mc01 is used directly; svuzp2 moves lane 1 to lane 0.
+      vc1 = svuzp2(mc00, mc00);
+      vc3 = svuzp2(mc01, mc01);
+
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1);
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, tmp32, tmp16, mc01);
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, tmp32, tmp16, vc3);
+    }
+
+    ptr_b += 4 * pad_k;
+  }
+
+  /* ---- remaining pair of columns ---- */
+  if (n & 2) {
+    ptr_c0 = ptr_c;
+    ptr_c1 = ptr_c0 + ldc;
+    ptr_c += 2 * ldc;
+    ptr_a = (bfloat16_t *)A;
+
+    /* 4 rows x 2 columns */
+    for (BLASLONG i = 0; i < m / 4; i++) {
+      ptr_a0 = ptr_a;
+      ptr_a += 4 * pad_k;
+
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+      INIT_C(1, 0);
+
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);
+
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+
+        MATMUL(0, 0);
+        MATMUL(1, 0);
+
+        ptr_a0 += 32;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);
+
+      ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);
+      ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc2);
+
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc2);
+
+      ptr_c0 += 4;
+      ptr_c1 += 4;
+    }
+
+    /* 2 rows x 2 columns */
+    if (m & 2) {
+      ptr_a0 = ptr_a;
+      ptr_a += 2 * pad_k;
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+
+        MATMUL(0, 0);
+
+        ptr_a0 += 16;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      vc0 = svuzp1(mc00, mc00);
+      vc1 = svuzp2(mc00, mc00);
+
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1);
+
+      ptr_c0 += 2;
+      ptr_c1 += 2;
+    }
+
+    /* 1 row x 2 columns */
+    if (m & 1) {
+      ptr_a0 = ptr_a;
+      ptr_b0 = ptr_b;
+      INIT_C(0, 0);
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+        MATMUL(0, 0);
+        ptr_a0 += 16;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      vc1 = svuzp2(mc00, mc00);
+
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1);
+    }
+
+    ptr_b += 2 * pad_k;
+  }
+
+  /* ---- last single column ---- */
+  if (n & 1) { // TODO: this case looks like pure overhead; find out whether
+               // it actually occurs in our use case.
+    ptr_c0 = ptr_c;
+    ptr_a = (bfloat16_t *)A;
+
+    /* 4 rows x 1 column: only the even lanes (ptr_c0) are kept. */
+    for (BLASLONG i = 0; i < m / 4; i++) {
+      ptr_a0 = ptr_a;
+      ptr_a += 4 * pad_k;
+
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+      INIT_C(1, 0);
+
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);
+
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+
+        MATMUL(0, 0);
+        MATMUL(1, 0);
+
+        ptr_a0 += 32;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+      ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);
+
+      ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);
+
+      UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);
+
+      ptr_c0 += 4;
+    }
+
+    /* 2 rows x 1 column */
+    if (m & 2) {
+      ptr_a0 = ptr_a;
+      ptr_a += 2 * pad_k;
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+
+        MATMUL(0, 0);
+
+        ptr_a0 += 16;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+
+      vc0 = svuzp1(mc00, mc00);
+
+      UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);
+
+      ptr_c0 += 2;
+    }
+
+    /* 1 row x 1 column: lane 0 of mc00 is the result. */
+    if (m & 1) {
+      ptr_a0 = ptr_a;
+      ptr_b0 = ptr_b;
+
+      INIT_C(0, 0);
+      for (BLASLONG p = 0; p < pad_k; p += 8) {
+
+        ma0 = svld1_bf16(pg16_all, ptr_a0);
+        mb0 = svld1_bf16(pg16_all, ptr_b0);
+
+        MATMUL(0, 0);
+        ptr_a0 += 16;
+        ptr_b0 += 16;
+      }
+
+      ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
+
+      UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
+    }
+  }
+
+  return 0;
+}
diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c
index add84f043..8872f2f56 100644
--- a/kernel/generic/gemmkernel_2x2.c
+++ b/kernel/generic/gemmkernel_2x2.c
@@ -1,26 +1,50 @@
+/***************************************************************************
+ * Copyright (c) 2025, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ * 3. Neither the name of the OpenBLAS project nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED.
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + #include "common.h" #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) static float -bfloat16tof32 (bfloat16 f16) +bfloat16tof32 (bfloat16 value) { - float result = 0; - unsigned short* q = (unsigned short*)(&result); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - q[0] = f16; -#else - q[1] = f16; -#endif + blasint one = 1; + float result; + sbf16tos_(&one, &value, &one, &result, &one); return result; } -static bfloat16 f32tobfloat16(float f32) { - unsigned short *q = (unsigned short *)(&f32); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - return q[0]; -#else - return q[1]; -#endif +#ifdef BGEMM +static bfloat16 f32tobfloat16(float value) { + blasint one = 1; + bfloat16 result; + sbstobf16_(&one, &value, &one, &result, &one); + return result; } +#endif #ifdef BGEMM #define ALPHA bfloat16tof32(alpha) diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 9eb5959dc..886895acc 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -88,7 +88,11 @@ gotoblas_t TABLE_NAME = { ssymv_LTS, ssymv_UTS, bgemm_kernelTS, bgemm_betaTS, +#if BGEMM_DEFAULT_UNROLL_M != BGEMM_DEFAULT_UNROLL_N bgemm_incopyTS, bgemm_itcopyTS, +#else + bgemm_oncopyTS, bgemm_otcopyTS, +#endif bgemm_oncopyTS, bgemm_otcopyTS, sbgemm_kernelTS, sbgemm_betaTS, diff --git a/param.h b/param.h index 681ff598e..662d08449 100644 --- a/param.h +++ b/param.h @@ -3593,6 +3593,13 @@ 
is a big desktop or server with abundant cache rather than a phone or embedded d #define GEMM_PREFERED_SIZE 8 #endif +#undef BGEMM_ALIGN_K +#undef BGEMM_DEFAULT_UNROLL_M +#undef BGEMM_DEFAULT_UNROLL_N +#define BGEMM_ALIGN_K 8 /* BGEMM settings for this param.h section: ALIGN_K is 8 and both default unroll factors are 4; the preceding #undef lines discard any earlier values. */ +#define BGEMM_DEFAULT_UNROLL_N 4 +#define BGEMM_DEFAULT_UNROLL_M 4 + #undef SBGEMM_ALIGN_K #undef SBGEMM_DEFAULT_UNROLL_M #undef SBGEMM_DEFAULT_UNROLL_N diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c index 8cddcba97..51d26adbc 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -33,92 +33,49 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define BGEMM BLASFUNC(bgemm) #define BGEMM_LARGEST 256 -typedef union +static float float16to32(bfloat16 value) /* Converts one bfloat16 to float by calling sbf16tos_ with all three blasint arguments pointing at the value 1; this replaces the removed bit-field union conversion. */ { - unsigned short v; -#if defined(_AIX) - struct __attribute__((packed)) -#else - struct -#endif - { -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - unsigned short s:1; - unsigned short e:8; - unsigned short m:7; -#else - unsigned short m:7; - unsigned short e:8; - unsigned short s:1; -#endif - } bits; -} bfloat16_bits; - -typedef union -{ - float v; -#if defined(_AIX) - struct __attribute__((packed)) -#else - struct -#endif - { -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - uint32_t s:1; - uint32_t e:8; - uint32_t m:23; -#else - uint32_t m:23; - uint32_t e:8; - uint32_t s:1; -#endif - } bits; -} float32_bits; - -float -float16to32 (bfloat16_bits f16) -{ - float32_bits f32; - f32.bits.s = f16.bits.s; - f32.bits.e = f16.bits.e; - f32.bits.m = (uint32_t) f16.bits.m << 16; - return f32.v; -} - -bfloat16 -float32to16 (float32_bits f32) -{ - bfloat16_bits f16; - f16.bits.s = f32.bits.s; - f16.bits.e = f32.bits.e; - f16.bits.m = (f32.bits.m >> 16) & 0x7f; - return f16.v; + blasint one = 1; + float result; + sbf16tos_(&one, &value, &one, &result, &one); + return result; } static float truncate_float(float value) { - bfloat16_bits f16 = (bfloat16_bits)float32to16((float32_bits)value); - return float16to32(f16); + blasint one = 1; /* passed as every blasint argument of the two conversion calls that follow */ 
+ bfloat16 tmp; + float result; + sbstobf16_(&one, &value, &one, &tmp, &one); /* convert value to bfloat16 */ + sbf16tos_(&one, &tmp, &one, &result, &one); /* and back to float, so result has passed through bfloat16 */ + return result; } -void *malloc_safe(size_t size) { +static void *malloc_safe(size_t size) { /* Forwards to malloc, asking for 1 byte when size is 0 so malloc is never called with a zero size. */ if (size == 0) return malloc(1); else return malloc(size); } +static int is_close(float a, float b, float rtol, float atol) { /* Returns non-zero when the absolute difference of a and b is at most atol plus rtol times the magnitude of b. The return type is spelled out because implicit int is not valid since C99. */ + return fabs(a - b) <= (atol + rtol*fabs(b)); +} + int main (int argc, char *argv[]) { blasint m, n, k; int i, j, l; blasint x, y; + blasint one = 1; int ret = 0; int loop = BGEMM_LARGEST; char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; - bfloat16 alpha_bf16 = float32to16((float32_bits)alpha); - bfloat16 beta_bf16 = float32to16((float32_bits)beta); + bfloat16 alpha_bf16; + sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one); /* bfloat16 copy of alpha */ + bfloat16 beta_bf16; + sbstobf16_(&one, &beta, &one, &beta_bf16, &one); /* bfloat16 copy of beta */ for (x = 0; x <= loop; x++) { @@ -127,23 +84,20 @@ main (int argc, char *argv[]) float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); - bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); - bfloat16_bits *CC = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); + bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); + bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); + bfloat16 *CC = (bfloat16 *)malloc_safe(m * n * sizeof(bfloat16)); /* CC is later cleared with m * n elements and indexed as CC[i * m + j], so it is sized m by n like C and DD */ FLOAT *DD = (FLOAT *)malloc_safe(m * n * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; - bfloat16 atmp,btmp; - blasint one=1; for (j = 0; j < m; j++) { for (i = 0; i < k; i++) { A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); - AA[j * k + i].v = atmp; + sbstobf16_(&one, &A[j*k+i], &one, 
&AA[j * k + i], &one); /* write the bfloat16 value straight into AA */ } } for (j = 0; j < n; j++) @@ -151,8 +105,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); - BB[j * k + i].v = btmp; + sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one); /* write the bfloat16 value straight into BB */ } } for (y = 0; y < 4; y++) @@ -168,7 +121,7 @@ main (int argc, char *argv[]) transB = 'T'; } - memset(CC, 0, m * n * sizeof(bfloat16_bits)); + memset(CC, 0, m * n * sizeof(bfloat16)); memset(DD, 0, m * n * sizeof(FLOAT)); memset(C, 0, m * n * sizeof(FLOAT)); @@ -198,10 +151,15 @@ main (int argc, char *argv[]) DD[i * m + j] += float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); } - if (fabs(float16to32(CC[i * m + j]) - truncate_float(C[i * m + j])) > 2.0) { + if (!is_close(float16to32(CC[i * m + j]), truncate_float(C[i * m + j]), 0.01, 0.001)) { /* CC element against the matching C element passed through bfloat16; rtol 0.01, atol 0.001 */ + printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n", + i, j, (int)k, float16to32(CC[i * m + j]), truncate_float(C[i * m + j])); /* k is declared blasint, which may be wider than int, so it is cast to match %d */ ret++; } - if (fabs(float16to32(CC[i * m + j]) - truncate_float(DD[i * m + j])) > 1.0) { + + if (!is_close(float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]), 0.0001, 0.00001)) { /* CC element against DD, the float sum of products of the bfloat16 inputs, passed through bfloat16; rtol 0.0001, atol 0.00001 */ + printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n", + i, j, (int)k, float16to32(CC[i * m + j]), truncate_float(DD[i * m + j])); /* same cast of the blasint k for %d */ ret++; }