Add optimized BGEMM kernel for NEOVERSEV1 targetpull/5374/head
| @@ -1,3 +1,31 @@ | |||||
| ############################################################################### | |||||
| # Copyright (c) 2025, The OpenBLAS Project | |||||
| # All rights reserved. | |||||
| # Redistribution and use in source and binary forms, with or without | |||||
| # modification, are permitted provided that the following conditions are | |||||
| # met: | |||||
| # 1. Redistributions of source code must retain the above copyright | |||||
| # notice, this list of conditions and the following disclaimer. | |||||
| # 2. Redistributions in binary form must reproduce the above copyright | |||||
| # notice, this list of conditions and the following disclaimer in | |||||
| # the documentation and/or other materials provided with the | |||||
| # distribution. | |||||
| # 3. Neither the name of the OpenBLAS project nor the names of | |||||
| # its contributors may be used to endorse or promote products | |||||
| # derived from this software without specific prior written permission. | |||||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| # POSSIBILITY OF SUCH DAMAGE. | |||||
| ############################################################################### | |||||
| TOPDIR = .. | TOPDIR = .. | ||||
| include $(TOPDIR)/Makefile.system | include $(TOPDIR)/Makefile.system | ||||
| @@ -56,7 +84,7 @@ GOTO_LAPACK_TARGETS= | |||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| GOTO_BFLOAT_TARGETS=sbgemm.goto | |||||
| GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto | |||||
| else | else | ||||
| GOTO_BFLOAT_TARGETS= | GOTO_BFLOAT_TARGETS= | ||||
| endif | endif | ||||
| @@ -635,6 +663,8 @@ zcholesky.essl : zcholesky.$(SUFFIX) | |||||
| ##################################### Sgemm #################################################### | ##################################### Sgemm #################################################### | ||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME) | |||||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||||
| sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME) | sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME) | ||||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | ||||
| endif | endif | ||||
| @@ -2970,6 +3000,8 @@ zcholesky.$(SUFFIX) : cholesky.c | |||||
| $(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^ | $(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^ | ||||
| ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
| bgemm.$(SUFFIX) : gemm.c | |||||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^ | |||||
| sbgemm.$(SUFFIX) : gemm.c | sbgemm.$(SUFFIX) : gemm.c | ||||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^ | $(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^ | ||||
| endif | endif | ||||
| @@ -33,6 +33,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| #ifdef DOUBLE | #ifdef DOUBLE | ||||
| #define GEMM BLASFUNC(dgemm) | #define GEMM BLASFUNC(dgemm) | ||||
| #elif defined(BFLOAT16) && defined(BGEMM) | |||||
| #define GEMM BLASFUNC(bgemm) | |||||
| #elif defined(BFLOAT16) | #elif defined(BFLOAT16) | ||||
| #define GEMM BLASFUNC(sbgemm) | #define GEMM BLASFUNC(sbgemm) | ||||
| #undef IFLOAT | #undef IFLOAT | ||||
| @@ -60,8 +62,18 @@ int main(int argc, char *argv[]){ | |||||
| IFLOAT *a, *b; | IFLOAT *a, *b; | ||||
| FLOAT *c; | FLOAT *c; | ||||
| #ifdef BGEMM | |||||
| blasint one=1; | |||||
| blasint two=2; | |||||
| float alpha_in[] = {1.0, 0.0}; | |||||
| float beta_in[] = {0.0, 0.0}; | |||||
| FLOAT alpha[2], beta[2]; | |||||
| sbstobf16_(&two, alpha_in, &one, alpha, &one); | |||||
| sbstobf16_(&two, beta_in, &one, beta, &one); | |||||
| #else | |||||
| FLOAT alpha[] = {1.0, 0.0}; | FLOAT alpha[] = {1.0, 0.0}; | ||||
| FLOAT beta [] = {0.0, 0.0}; | FLOAT beta [] = {0.0, 0.0}; | ||||
| #endif | |||||
| char transa = 'N'; | char transa = 'N'; | ||||
| char transb = 'N'; | char transb = 'N'; | ||||
| blasint m, n, k, i, j, lda, ldb, ldc; | blasint m, n, k, i, j, lda, ldb, ldc; | ||||
| @@ -30,10 +30,16 @@ | |||||
| #define COMMON_B_H | #define COMMON_B_H | ||||
| #ifndef DYNAMIC_ARCH | #ifndef DYNAMIC_ARCH | ||||
| #define BGEMM_ONCOPY bgemm_oncopy | |||||
| #define BGEMM_OTCOPY bgemm_otcopy | |||||
| #define BGEMM_INCOPY bgemm_incopy | |||||
| #define BGEMM_ITCOPY bgemm_itcopy | |||||
| #define BGEMM_ONCOPY bgemm_oncopy | |||||
| #define BGEMM_OTCOPY bgemm_otcopy | |||||
| #if BGEMM_DEFAULT_UNROLL_M == BGEMM_DEFAULT_UNROLL_N | |||||
| #define BGEMM_INCOPY bgemm_oncopy | |||||
| #define BGEMM_ITCOPY bgemm_otcopy | |||||
| #else | |||||
| #define BGEMM_INCOPY bgemm_incopy | |||||
| #define BGEMM_ITCOPY bgemm_itcopy | |||||
| #endif | |||||
| #define BGEMM_BETA bgemm_beta | #define BGEMM_BETA bgemm_beta | ||||
| #define BGEMM_KERNEL bgemm_kernel | #define BGEMM_KERNEL bgemm_kernel | ||||
| @@ -1,3 +1,31 @@ | |||||
| ############################################################################### | |||||
| # Copyright (c) 2025, The OpenBLAS Project | |||||
| # All rights reserved. | |||||
| # Redistribution and use in source and binary forms, with or without | |||||
| # modification, are permitted provided that the following conditions are | |||||
| # met: | |||||
| # 1. Redistributions of source code must retain the above copyright | |||||
| # notice, this list of conditions and the following disclaimer. | |||||
| # 2. Redistributions in binary form must reproduce the above copyright | |||||
| # notice, this list of conditions and the following disclaimer in | |||||
| # the documentation and/or other materials provided with the | |||||
| # distribution. | |||||
| # 3. Neither the name of the OpenBLAS project nor the names of | |||||
| # its contributors may be used to endorse or promote products | |||||
| # derived from this software without specific prior written permission. | |||||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| # POSSIBILITY OF SUCH DAMAGE. | |||||
| ############################################################################### | |||||
| TOPDIR = .. | TOPDIR = .. | ||||
| include $(TOPDIR)/Makefile.system | include $(TOPDIR)/Makefile.system | ||||
| @@ -526,7 +554,7 @@ ifneq ($(BUILD_COMPLEX16),1) | |||||
| ZBLASOBJS= | ZBLASOBJS= | ||||
| endif | endif | ||||
| FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS) | |||||
| FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS) | |||||
| ifeq ($(EXPRECISION), 1) | ifeq ($(EXPRECISION), 1) | ||||
| FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) | FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) | ||||
| @@ -674,6 +674,10 @@ ZBLASOBJS += \ | |||||
| endif | endif | ||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| BGEMMINCOPYOBJ_P = $(BGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| BGEMMITCOPYOBJ_P = $(BGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| BGEMMONCOPYOBJ_P = $(BGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| BGEMMOTCOPYOBJ_P = $(BGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | |||||
| SBGEMMINCOPYOBJ_P = $(SBGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | SBGEMMINCOPYOBJ_P = $(SBGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| SBGEMMITCOPYOBJ_P = $(SBGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | SBGEMMITCOPYOBJ_P = $(SBGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) | ||||
| @@ -2998,6 +3002,20 @@ $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| $(BGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMONCOPY) | |||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(BGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMOTCOPY) | |||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N)) | |||||
| $(BGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMINCOPY) | |||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(BGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMITCOPY) | |||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| endif | |||||
| $(SBGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMONCOPY) | $(SBGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMONCOPY) | ||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | ||||
| @@ -3010,7 +3028,6 @@ $(SBGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMINCOPY) | |||||
| $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY) | $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY) | ||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | ||||
| endif | endif | ||||
| endif | endif | ||||
| @@ -3137,6 +3154,8 @@ endif | |||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| $(KDIR)bgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) $(BGEMMDEPEND) | |||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) | $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) | ||||
| $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ | ||||
| endif | endif | ||||
| @@ -1,3 +1,31 @@ | |||||
| ############################################################################### | |||||
| # Copyright (c) 2025, The OpenBLAS Project | |||||
| # All rights reserved. | |||||
| # Redistribution and use in source and binary forms, with or without | |||||
| # modification, are permitted provided that the following conditions are | |||||
| # met: | |||||
| # 1. Redistributions of source code must retain the above copyright | |||||
| # notice, this list of conditions and the following disclaimer. | |||||
| # 2. Redistributions in binary form must reproduce the above copyright | |||||
| # notice, this list of conditions and the following disclaimer in | |||||
| # the documentation and/or other materials provided with the | |||||
| # distribution. | |||||
| # 3. Neither the name of the OpenBLAS project nor the names of | |||||
| # its contributors may be used to endorse or promote products | |||||
| # derived from this software without specific prior written permission. | |||||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| # POSSIBILITY OF SUCH DAMAGE. | |||||
| ############################################################################### | |||||
| include $(KERNELDIR)/KERNEL.ARMV8SVE | include $(KERNELDIR)/KERNEL.ARMV8SVE | ||||
| SGEMVNKERNEL = gemv_n_sve_v1x3.c | SGEMVNKERNEL = gemv_n_sve_v1x3.c | ||||
| @@ -5,6 +33,19 @@ DGEMVNKERNEL = gemv_n_sve_v1x3.c | |||||
| SGEMVTKERNEL = gemv_t_sve_v1x3.c | SGEMVTKERNEL = gemv_t_sve_v1x3.c | ||||
| DGEMVTKERNEL = gemv_t_sve_v1x3.c | DGEMVTKERNEL = gemv_t_sve_v1x3.c | ||||
| ifeq ($(BUILD_BFLOAT16), 1) | ifeq ($(BUILD_BFLOAT16), 1) | ||||
| BGEMM_BETA = bgemm_beta_neon.c | |||||
| BGEMMKERNEL = bgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversev1.c | |||||
| ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N)) | |||||
| BGEMMINCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_M)_neoversev1.c | |||||
| BGEMMITCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_M)_neoversev1.c | |||||
| BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) | |||||
| BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) | |||||
| endif | |||||
| BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversev1.c | |||||
| BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c | |||||
| BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||||
| BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||||
| SBGEMM_BETA = sbgemm_beta_neoversev1.c | SBGEMM_BETA = sbgemm_beta_neoversev1.c | ||||
| SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c | SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c | ||||
| ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) | ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) | ||||
| @@ -21,4 +62,4 @@ SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||||
| SBGEMVNKERNEL = sbgemv_n_neon.c | SBGEMVNKERNEL = sbgemv_n_neon.c | ||||
| SBGEMVTKERNEL = sbgemv_t_bfdot.c | SBGEMVTKERNEL = sbgemv_t_bfdot.c | ||||
| endif | |||||
| endif | |||||
| @@ -0,0 +1,107 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2025, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| * POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #include "common.h" | |||||
| #include <arm_neon.h> | |||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, IFLOAT *dummy2, | |||||
| BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, | |||||
| BLASLONG ldc) { | |||||
| BLASLONG i, j; | |||||
| BLASLONG chunk, remain; | |||||
| bfloat16_t *ptr_c, *ptr_c0; | |||||
| bfloat16x8_t x0, z0; | |||||
| float32x4_t y0, y1; | |||||
| float x; | |||||
| bfloat16_t z; | |||||
| bfloat16_t zero_bf16 = vcvth_bf16_f32(0.0f); | |||||
| bfloat16x8_t zeros = vdupq_n_bf16(zero_bf16); | |||||
| bfloat16_t beta_bf16; | |||||
| memcpy(&beta_bf16, &beta_in, sizeof(bfloat16_t)); | |||||
| float beta = vcvtah_f32_bf16(beta_bf16); | |||||
| float32x4_t beta_neon = vdupq_n_f32(beta); | |||||
| ptr_c = (bfloat16_t *)c; | |||||
| chunk = m >> 3; | |||||
| remain = m & 7; | |||||
| if (beta == 0.0f){ | |||||
| for (j = 0; j < n; j ++){ | |||||
| ptr_c0 = ptr_c; | |||||
| ptr_c += ldc; | |||||
| for (i = 0; i < chunk; i ++){ | |||||
| vst1q_bf16(ptr_c0, zeros); | |||||
| ptr_c0 += 8; | |||||
| } | |||||
| for (i = 0; i < remain; i ++){ | |||||
| ptr_c0[0] = zero_bf16; | |||||
| ptr_c0 ++; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (j = 0; j < n; j ++){ | |||||
| ptr_c0 = ptr_c; | |||||
| ptr_c += ldc; | |||||
| for (i = 0; i < chunk; i ++){ | |||||
| x0 = vld1q_bf16(ptr_c0); | |||||
| y0 = vcvtq_low_f32_bf16(x0); | |||||
| y1 = vcvtq_high_f32_bf16(x0); | |||||
| y0 = vmulq_f32(y0, beta_neon); | |||||
| y1 = vmulq_f32(y1, beta_neon); | |||||
| z0 = vcvtq_low_bf16_f32(y0); | |||||
| z0 = vcvtq_high_bf16_f32(z0, y1); | |||||
| vst1q_bf16(ptr_c0, z0); | |||||
| ptr_c0 += 8; | |||||
| } | |||||
| for (i = 0; i < remain; i ++){ | |||||
| x = vcvtah_f32_bf16(ptr_c0[0]); | |||||
| z = vcvth_bf16_f32(x * beta); | |||||
| ptr_c0[0] = z; | |||||
| ptr_c0 ++; | |||||
| } | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| }; | |||||
| @@ -0,0 +1,50 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2025, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| * POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #include <arm_sve.h> | |||||
| #include "common.h" | |||||
| #define ALPHA_ONE | |||||
| #include "bgemm_kernel_4x4_neoversev1_impl.c" | |||||
| #undef ALPHA_ONE | |||||
| #undef UPDATE_C | |||||
| #include "bgemm_kernel_4x4_neoversev1_impl.c" | |||||
| int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||||
| FLOAT *C, BLASLONG ldc) { | |||||
| bfloat16_t alpha_bf16; | |||||
| memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||||
| float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); | |||||
| if (alpha_f32 == 1.0f) | |||||
| return bgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); | |||||
| else | |||||
| return bgemm_kernel_neoversev1_alpha(m, n, k, alpha, A, B, C, ldc); | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,429 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2025, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| * POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #include <arm_sve.h> | |||||
| #include "common.h" | |||||
| #define INIT_C(M, N) mc##M##N = svdup_f32(0); | |||||
| #define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N); | |||||
| #define INIT_C_4x4 \ | |||||
| do { \ | |||||
| INIT_C(0, 0); \ | |||||
| INIT_C(0, 1); \ | |||||
| INIT_C(1, 0); \ | |||||
| INIT_C(1, 1); \ | |||||
| } while (0); | |||||
| #ifdef ALPHA_ONE | |||||
| #define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32) \ | |||||
| do { \ | |||||
| TMP32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \ | |||||
| TMP32 = svadd_z((PG32), SRC32, TMP32); \ | |||||
| TMP16 = svcvt_bf16_f32_z((PG32), TMP32); \ | |||||
| TMP16 = svuzp1_bf16(TMP16, TMP16); \ | |||||
| svst1_bf16((PG16), (PTR), TMP16); \ | |||||
| } while (0) | |||||
| #else | |||||
| #define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32) \ | |||||
| do { \ | |||||
| TMP32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \ | |||||
| TMP32 = svmad_z((PG32), svalpha, SRC32, TMP32); \ | |||||
| TMP16 = svcvt_bf16_f32_z((PG32), TMP32); \ | |||||
| TMP16 = svuzp1_bf16(TMP16, TMP16); \ | |||||
| svst1_bf16((PG16), (PTR), TMP16); \ | |||||
| } while (0) | |||||
| #endif | |||||
| #define ZIP_EVEN_ELEMENTS(PG, mc0, mc1, tmp, vc) \ | |||||
| do { \ | |||||
| (tmp) = svuzp1_f32((mc0), (mc1)); \ | |||||
| (vc) = svcompact_f32((PG), (tmp)); \ | |||||
| } while (0) | |||||
| #define ZIP_ODD_ELEMENTS(PG, mc0, mc1, tmp, vc) \ | |||||
| do { \ | |||||
| (tmp) = svuzp2_f32((mc0), (mc1)); \ | |||||
| (vc) = svcompact_f32((PG), (tmp)); \ | |||||
| } while (0) | |||||
| #define ACCUMULATE_LAST4_TO_FIRST4(M, N, TMP) \ | |||||
| do { \ | |||||
| TMP = svext_f32(mc##M##N, mc##M##N, 4); \ | |||||
| mc##M##N = svadd_f32_z(svptrue_b32(), mc##M##N, (TMP)); \ | |||||
| } while (0) | |||||
| #ifdef ALPHA_ONE | |||||
| int bgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, | |||||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, | |||||
| FLOAT *C, BLASLONG ldc) | |||||
| #else | |||||
| int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, | |||||
| FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, | |||||
| BLASLONG ldc) | |||||
| #endif | |||||
| { | |||||
| BLASLONG pad_k = (k + 7) & ~7; | |||||
| svbfloat16_t ma0, ma1, mb0, mb1; | |||||
| svfloat32_t mc00, mc01, mc10, mc11, vc0, vc1, vc2, vc3; | |||||
| svfloat32_t tmp; | |||||
| #ifndef ALPHA_ONE | |||||
| bfloat16_t alpha_bf16; | |||||
| memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||||
| svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16)); | |||||
| #endif | |||||
| svbool_t pg16_all = svptrue_b16(); | |||||
| svbool_t pg32_first_1 = svwhilelt_b32(0, 1); | |||||
| svbool_t pg32_first_2 = svwhilelt_b32(0, 2); | |||||
| svbool_t pg32_first_4 = svwhilelt_b32(0, 4); | |||||
| svbool_t pg16_first_1 = svwhilelt_b16(0, 1); | |||||
| svbool_t pg16_first_2 = svwhilelt_b16(0, 2); | |||||
| svbool_t pg16_first_4 = svwhilelt_b16(0, 4); | |||||
| svbool_t pg32_select_first_2_per_quadword = svdupq_b32(1, 1, 0, 0); | |||||
| bfloat16_t *ptr_a = (bfloat16_t *)A; | |||||
| bfloat16_t *ptr_b = (bfloat16_t *)B; | |||||
| bfloat16_t *ptr_c = (bfloat16_t *)C; | |||||
| bfloat16_t *ptr_a0; | |||||
| bfloat16_t *ptr_b0; | |||||
| bfloat16_t *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3; | |||||
| svfloat32_t tmp32; | |||||
| svbfloat16_t tmp16; | |||||
| for (BLASLONG j = 0; j < n / 4; j++) { | |||||
| ptr_c0 = ptr_c; | |||||
| ptr_c1 = ptr_c0 + ldc; | |||||
| ptr_c2 = ptr_c1 + ldc; | |||||
| ptr_c3 = ptr_c2 + ldc; | |||||
| ptr_c += 4 * ldc; | |||||
| ptr_a = (bfloat16_t *)A; | |||||
| for (BLASLONG i = 0; i < m / 4; i++) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 4 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C_4x4; | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); | |||||
| MATMUL(0, 0); | |||||
| MATMUL(0, 1); | |||||
| MATMUL(1, 0); | |||||
| MATMUL(1, 1); | |||||
| ptr_a0 += 32; | |||||
| ptr_b0 += 32; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(1, 1, tmp); | |||||
| ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); | |||||
| ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc1); | |||||
| ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc2); | |||||
| ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc3); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc1); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, tmp32, tmp16, vc2); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, tmp32, tmp16, vc3); | |||||
| ptr_c0 += 4; | |||||
| ptr_c1 += 4; | |||||
| ptr_c2 += 4; | |||||
| ptr_c3 += 4; | |||||
| } | |||||
| if (m & 2) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 2 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| INIT_C(0, 1); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); | |||||
| MATMUL(0, 0); | |||||
| MATMUL(0, 1); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 32; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); | |||||
| vc0 = svuzp1(mc00, mc00); | |||||
| vc1 = svuzp2(mc00, mc00); | |||||
| vc2 = svuzp1(mc01, mc01); | |||||
| vc3 = svuzp2(mc01, mc01); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, tmp32, tmp16, vc2); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, tmp32, tmp16, vc3); | |||||
| ptr_c0 += 2; | |||||
| ptr_c1 += 2; | |||||
| ptr_c2 += 2; | |||||
| ptr_c3 += 2; | |||||
| } | |||||
| if (m & 1) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| INIT_C(0, 1); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); | |||||
| MATMUL(0, 0); | |||||
| MATMUL(0, 1); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 32; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); | |||||
| // use compact is more straightforward | |||||
| vc1 = svuzp2(mc00, mc00); | |||||
| vc3 = svuzp2(mc01, mc01); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, tmp32, tmp16, mc01); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, tmp32, tmp16, vc3); | |||||
| } | |||||
| ptr_b += 4 * pad_k; | |||||
| } | |||||
| if (n & 2) { | |||||
| ptr_c0 = ptr_c; | |||||
| ptr_c1 = ptr_c0 + ldc; | |||||
| ptr_c += 2 * ldc; | |||||
| ptr_a = (bfloat16_t *)A; | |||||
| for (BLASLONG i = 0; i < m / 4; i++) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 4 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| INIT_C(1, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| MATMUL(1, 0); | |||||
| ptr_a0 += 32; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); | |||||
| ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); | |||||
| ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc2); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc2); | |||||
| ptr_c0 += 4; | |||||
| ptr_c1 += 4; | |||||
| } | |||||
| if (m & 2) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 2 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| vc0 = svuzp1(mc00, mc00); | |||||
| vc1 = svuzp2(mc00, mc00); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1); | |||||
| ptr_c0 += 2; | |||||
| ptr_c1 += 2; | |||||
| } | |||||
| if (m & 1) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| vc1 = svuzp2(mc00, mc00); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1); | |||||
| } | |||||
| ptr_b += 2 * pad_k; | |||||
| } | |||||
| if (n & 1) { // TODO: this case seems a overhead. find out whether it's in our | |||||
| // case. | |||||
| ptr_c0 = ptr_c; | |||||
| ptr_a = (bfloat16_t *)A; | |||||
| for (BLASLONG i = 0; i < m / 4; i++) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 4 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| INIT_C(1, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| MATMUL(1, 0); | |||||
| ptr_a0 += 32; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); | |||||
| ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); | |||||
| UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); | |||||
| ptr_c0 += 4; | |||||
| } | |||||
| if (m & 2) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_a += 2 * pad_k; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| vc0 = svuzp1(mc00, mc00); | |||||
| UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); | |||||
| ptr_c0 += 2; | |||||
| } | |||||
| if (m & 1) { | |||||
| ptr_a0 = ptr_a; | |||||
| ptr_b0 = ptr_b; | |||||
| INIT_C(0, 0); | |||||
| for (BLASLONG p = 0; p < pad_k; p += 8) { | |||||
| ma0 = svld1_bf16(pg16_all, ptr_a0); | |||||
| mb0 = svld1_bf16(pg16_all, ptr_b0); | |||||
| MATMUL(0, 0); | |||||
| ptr_a0 += 16; | |||||
| ptr_b0 += 16; | |||||
| } | |||||
| ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); | |||||
| UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| @@ -1,26 +1,50 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2025, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||||
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||||
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||||
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||||
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||||
| * POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #include "common.h" | #include "common.h" | ||||
| #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) | #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) | ||||
| static float | static float | ||||
| bfloat16tof32 (bfloat16 f16) | |||||
| bfloat16tof32 (bfloat16 value) | |||||
| { | { | ||||
| float result = 0; | |||||
| unsigned short* q = (unsigned short*)(&result); | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| q[0] = f16; | |||||
| #else | |||||
| q[1] = f16; | |||||
| #endif | |||||
| blasint one = 1; | |||||
| float result; | |||||
| sbf16tos_(&one, &value, &one, &result, &one); | |||||
| return result; | return result; | ||||
| } | } | ||||
| static bfloat16 f32tobfloat16(float f32) { | |||||
| unsigned short *q = (unsigned short *)(&f32); | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| return q[0]; | |||||
| #else | |||||
| return q[1]; | |||||
| #endif | |||||
| #ifdef BGEMM | |||||
| static bfloat16 f32tobfloat16(float value) { | |||||
| blasint one = 1; | |||||
| bfloat16 result; | |||||
| sbstobf16_(&one, &value, &one, &result, &one); | |||||
| return result; | |||||
| } | } | ||||
| #endif | |||||
| #ifdef BGEMM | #ifdef BGEMM | ||||
| #define ALPHA bfloat16tof32(alpha) | #define ALPHA bfloat16tof32(alpha) | ||||
| @@ -88,7 +88,11 @@ gotoblas_t TABLE_NAME = { | |||||
| ssymv_LTS, ssymv_UTS, | ssymv_LTS, ssymv_UTS, | ||||
| bgemm_kernelTS, bgemm_betaTS, | bgemm_kernelTS, bgemm_betaTS, | ||||
| #if BGEMM_DEFAULT_UNROLL_M != BGEMM_DEFAULT_UNROLL_N | |||||
| bgemm_incopyTS, bgemm_itcopyTS, | bgemm_incopyTS, bgemm_itcopyTS, | ||||
| #else | |||||
| bgemm_oncopyTS, bgemm_otcopyTS, | |||||
| #endif | |||||
| bgemm_oncopyTS, bgemm_otcopyTS, | bgemm_oncopyTS, bgemm_otcopyTS, | ||||
| sbgemm_kernelTS, sbgemm_betaTS, | sbgemm_kernelTS, sbgemm_betaTS, | ||||
| @@ -3593,6 +3593,13 @@ is a big desktop or server with abundant cache rather than a phone or embedded d | |||||
| #define GEMM_PREFERED_SIZE 8 | #define GEMM_PREFERED_SIZE 8 | ||||
| #endif | #endif | ||||
| #undef BGEMM_ALIGN_K | |||||
| #undef BGEMM_DEFAULT_UNROLL_M | |||||
| #undef BGEMM_DEFAULT_UNROLL_N | |||||
| #define BGEMM_ALIGN_K 8 | |||||
| #define BGEMM_DEFAULT_UNROLL_N 4 | |||||
| #define BGEMM_DEFAULT_UNROLL_M 4 | |||||
| #undef SBGEMM_ALIGN_K | #undef SBGEMM_ALIGN_K | ||||
| #undef SBGEMM_DEFAULT_UNROLL_M | #undef SBGEMM_DEFAULT_UNROLL_M | ||||
| #undef SBGEMM_DEFAULT_UNROLL_N | #undef SBGEMM_DEFAULT_UNROLL_N | ||||
| @@ -33,92 +33,49 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| #define BGEMM BLASFUNC(bgemm) | #define BGEMM BLASFUNC(bgemm) | ||||
| #define BGEMM_LARGEST 256 | #define BGEMM_LARGEST 256 | ||||
| typedef union | |||||
| static float float16to32(bfloat16 value) | |||||
| { | { | ||||
| unsigned short v; | |||||
| #if defined(_AIX) | |||||
| struct __attribute__((packed)) | |||||
| #else | |||||
| struct | |||||
| #endif | |||||
| { | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| unsigned short s:1; | |||||
| unsigned short e:8; | |||||
| unsigned short m:7; | |||||
| #else | |||||
| unsigned short m:7; | |||||
| unsigned short e:8; | |||||
| unsigned short s:1; | |||||
| #endif | |||||
| } bits; | |||||
| } bfloat16_bits; | |||||
| typedef union | |||||
| { | |||||
| float v; | |||||
| #if defined(_AIX) | |||||
| struct __attribute__((packed)) | |||||
| #else | |||||
| struct | |||||
| #endif | |||||
| { | |||||
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ | |||||
| uint32_t s:1; | |||||
| uint32_t e:8; | |||||
| uint32_t m:23; | |||||
| #else | |||||
| uint32_t m:23; | |||||
| uint32_t e:8; | |||||
| uint32_t s:1; | |||||
| #endif | |||||
| } bits; | |||||
| } float32_bits; | |||||
| float | |||||
| float16to32 (bfloat16_bits f16) | |||||
| { | |||||
| float32_bits f32; | |||||
| f32.bits.s = f16.bits.s; | |||||
| f32.bits.e = f16.bits.e; | |||||
| f32.bits.m = (uint32_t) f16.bits.m << 16; | |||||
| return f32.v; | |||||
| } | |||||
| bfloat16 | |||||
| float32to16 (float32_bits f32) | |||||
| { | |||||
| bfloat16_bits f16; | |||||
| f16.bits.s = f32.bits.s; | |||||
| f16.bits.e = f32.bits.e; | |||||
| f16.bits.m = (f32.bits.m >> 16) & 0x7f; | |||||
| return f16.v; | |||||
| blasint one = 1; | |||||
| float result; | |||||
| sbf16tos_(&one, &value, &one, &result, &one); | |||||
| return result; | |||||
| } | } | ||||
| static float truncate_float(float value) { | static float truncate_float(float value) { | ||||
| bfloat16_bits f16 = (bfloat16_bits)float32to16((float32_bits)value); | |||||
| return float16to32(f16); | |||||
| blasint one = 1; | |||||
| bfloat16 tmp; | |||||
| float result; | |||||
| sbstobf16_(&one, &value, &one, &tmp, &one); | |||||
| sbf16tos_(&one, &tmp, &one, &result, &one); | |||||
| return result; | |||||
| } | } | ||||
| void *malloc_safe(size_t size) { | |||||
| static void *malloc_safe(size_t size) { | |||||
| if (size == 0) | if (size == 0) | ||||
| return malloc(1); | return malloc(1); | ||||
| else | else | ||||
| return malloc(size); | return malloc(size); | ||||
| } | } | ||||
| static is_close(float a, float b, float rtol, float atol) { | |||||
| return fabs(a - b) <= (atol + rtol*fabs(b)); | |||||
| } | |||||
| int | int | ||||
| main (int argc, char *argv[]) | main (int argc, char *argv[]) | ||||
| { | { | ||||
| blasint m, n, k; | blasint m, n, k; | ||||
| int i, j, l; | int i, j, l; | ||||
| blasint x, y; | blasint x, y; | ||||
| blasint one = 1; | |||||
| int ret = 0; | int ret = 0; | ||||
| int loop = BGEMM_LARGEST; | int loop = BGEMM_LARGEST; | ||||
| char transA = 'N', transB = 'N'; | char transA = 'N', transB = 'N'; | ||||
| float alpha = 1.0, beta = 0.0; | float alpha = 1.0, beta = 0.0; | ||||
| bfloat16 alpha_bf16 = float32to16((float32_bits)alpha); | |||||
| bfloat16 beta_bf16 = float32to16((float32_bits)beta); | |||||
| bfloat16 alpha_bf16; | |||||
| sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one); | |||||
| bfloat16 beta_bf16; | |||||
| sbstobf16_(&one, &beta, &one, &beta_bf16, &one); | |||||
| for (x = 0; x <= loop; x++) | for (x = 0; x <= loop; x++) | ||||
| { | { | ||||
| @@ -127,23 +84,20 @@ main (int argc, char *argv[]) | |||||
| float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); | float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); | ||||
| float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); | float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); | ||||
| float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); | float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); | ||||
| bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); | |||||
| bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); | |||||
| bfloat16_bits *CC = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); | |||||
| bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); | |||||
| bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); | |||||
| bfloat16 *CC = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); | |||||
| FLOAT *DD = (FLOAT *)malloc_safe(m * n * sizeof(FLOAT)); | FLOAT *DD = (FLOAT *)malloc_safe(m * n * sizeof(FLOAT)); | ||||
| if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || | if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || | ||||
| (DD == NULL) || (CC == NULL)) | (DD == NULL) || (CC == NULL)) | ||||
| return 1; | return 1; | ||||
| bfloat16 atmp,btmp; | |||||
| blasint one=1; | |||||
| for (j = 0; j < m; j++) | for (j = 0; j < m; j++) | ||||
| { | { | ||||
| for (i = 0; i < k; i++) | for (i = 0; i < k; i++) | ||||
| { | { | ||||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | ||||
| sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); | |||||
| AA[j * k + i].v = atmp; | |||||
| sbstobf16_(&one, &A[j*k+i], &one, &AA[j * k + i], &one); | |||||
| } | } | ||||
| } | } | ||||
| for (j = 0; j < n; j++) | for (j = 0; j < n; j++) | ||||
| @@ -151,8 +105,7 @@ main (int argc, char *argv[]) | |||||
| for (i = 0; i < k; i++) | for (i = 0; i < k; i++) | ||||
| { | { | ||||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | ||||
| sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); | |||||
| BB[j * k + i].v = btmp; | |||||
| sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one); | |||||
| } | } | ||||
| } | } | ||||
| for (y = 0; y < 4; y++) | for (y = 0; y < 4; y++) | ||||
| @@ -168,7 +121,7 @@ main (int argc, char *argv[]) | |||||
| transB = 'T'; | transB = 'T'; | ||||
| } | } | ||||
| memset(CC, 0, m * n * sizeof(bfloat16_bits)); | |||||
| memset(CC, 0, m * n * sizeof(bfloat16)); | |||||
| memset(DD, 0, m * n * sizeof(FLOAT)); | memset(DD, 0, m * n * sizeof(FLOAT)); | ||||
| memset(C, 0, m * n * sizeof(FLOAT)); | memset(C, 0, m * n * sizeof(FLOAT)); | ||||
| @@ -198,10 +151,15 @@ main (int argc, char *argv[]) | |||||
| DD[i * m + j] += | DD[i * m + j] += | ||||
| float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); | float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); | ||||
| } | } | ||||
| if (fabs(float16to32(CC[i * m + j]) - truncate_float(C[i * m + j])) > 2.0) { | |||||
| if (!is_close(float16to32(CC[i * m + j]), truncate_float(C[i * m + j]), 0.01, 0.001)) { | |||||
| printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n", | |||||
| i, j, k, float16to32(CC[i * m + j]), truncate_float(C[i * m + j])); | |||||
| ret++; | ret++; | ||||
| } | } | ||||
| if (fabs(float16to32(CC[i * m + j]) - truncate_float(DD[i * m + j])) > 1.0) { | |||||
| if (!is_close(float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]), 0.0001, 0.00001)) { | |||||
| printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n", | |||||
| i, j, k, float16to32(CC[i * m + j]), truncate_float(DD[i * m + j])); | |||||
| ret++; | ret++; | ||||
| } | } | ||||