diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9edf3d6ea..041582892 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -236,9 +236,12 @@ In chronological order: * Annop Wongwathanarat * [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1 * [2025-01-21] Optimize gemv_t_sve_v1x3 kernel + * [2025-02-26] Add sbgemv_t_bfdot kernel -* Marek Michalowski +* Marek Michalowski * [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1` + * [2025-02-18] Add thread throttling profile for SGEMM on `NEOVERSEV2` + * [2025-02-19] Add thread throttling profile for SGEMV on `NEOVERSEV2` * Ye Tao * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 diff --git a/driver/others/dynamic_arm64.c b/driver/others/dynamic_arm64.c index 5d48f6806..31821ae78 100644 --- a/driver/others/dynamic_arm64.c +++ b/driver/others/dynamic_arm64.c @@ -162,7 +162,7 @@ extern gotoblas_t gotoblas_ARMV9SME; extern gotoblas_t gotoblas_THUNDERX3T110; #endif -#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1 +#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEN2 extern void openblas_warning(int verbose, const char * msg); #define FALLBACK_VERBOSE 1 diff --git a/interface/gemm.c b/interface/gemm.c index 2cd7d7b5c..d36925629 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -177,6 +177,7 @@ static int init_amxtile_permission() { } #endif +#ifdef SMP #ifdef DYNAMIC_ARCH extern char* gotoblas_corename(void); #endif @@ -198,14 +199,37 @@ static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) { } #endif +#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2) +static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) { + return + MNK < 125000L ? 1 + : MNK < 1092727L ? MIN(ncpu, 6) + : MNK < 2628072L ? MIN(ncpu, 8) + : MNK < 8000000L ? MIN(ncpu, 12) + : MNK < 20346417L ? MIN(ncpu, 16) + : MNK < 57066625L ? MIN(ncpu, 24) + : MNK < 91125000L ? MIN(ncpu, 28) + : MNK < 238328000L ? MIN(ncpu, 40) + : MNK < 454756609L ? MIN(ncpu, 48) + : MNK < 857375000L ? 
MIN(ncpu, 56) + : MNK < 1073741824L ? MIN(ncpu, 64) + : ncpu; +} +#endif + static inline int get_gemm_optimal_nthreads(double MNK) { int ncpu = num_cpu_avail(3); #if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); +#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) + return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu); #elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) if (strcmp(gotoblas_corename(), "neoversev1") == 0) { return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); } + if (strcmp(gotoblas_corename(), "neoversev2") == 0) { + return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu); + } #endif if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) { return 1; @@ -219,6 +243,7 @@ static inline int get_gemm_optimal_nthreads(double MNK) { } } } +#endif #ifndef CBLAS diff --git a/interface/gemv.c b/interface/gemv.c index f91f364ee..533ea3a56 100644 --- a/interface/gemv.c +++ b/interface/gemv.c @@ -63,6 +63,7 @@ static int (*gemv_thread[])(BLASLONG, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT }; #endif +#ifdef SMP #ifdef DYNAMIC_ARCH extern char* gotoblas_corename(void); #endif @@ -77,14 +78,30 @@ static inline int get_gemv_optimal_nthreads_neoversev1(BLASLONG MN, int ncpu) { } #endif +#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2) +static inline int get_gemv_optimal_nthreads_neoversev2(BLASLONG MN, int ncpu) { + return + MN < 24964L ? 1 + : MN < 65536L ? MIN(ncpu, 8) + : MN < 262144L ? MIN(ncpu, 32) + : MN < 1638400L ? 
MIN(ncpu, 64) + : ncpu; +} +#endif + static inline int get_gemv_optimal_nthreads(BLASLONG MN) { int ncpu = num_cpu_avail(3); #if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) return get_gemv_optimal_nthreads_neoversev1(MN, ncpu); +#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) + return get_gemv_optimal_nthreads_neoversev2(MN, ncpu); #elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) if (strcmp(gotoblas_corename(), "neoversev1") == 0) { return get_gemv_optimal_nthreads_neoversev1(MN, ncpu); } + if (strcmp(gotoblas_corename(), "neoversev2") == 0) { + return get_gemv_optimal_nthreads_neoversev2(MN, ncpu); + } #endif if ( MN < 115200L * GEMM_MULTITHREAD_THRESHOLD ) @@ -92,6 +109,7 @@ static inline int get_gemv_optimal_nthreads(BLASLONG MN) { else return num_cpu_avail(2); } +#endif #ifndef CBLAS @@ -232,13 +250,6 @@ void CNAME(enum CBLAS_ORDER order, if (alpha == ZERO) return; -#if 0 -/* this optimization causes stack corruption on x86_64 under OSX, Windows and FreeBSD */ - if (trans == 0 && incx == 1 && incy == 1 && m*n < 2304 *GEMM_MULTITHREAD_THRESHOLD) { - GEMV_N(m, n, 0, alpha, a, lda, x, incx, y, incy, NULL); - return; - } -#endif IDEBUG_START; FUNCTION_PROFILE_START(); diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 2f7400113..e4e1cfde3 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -198,3 +198,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) +SBGEMVTKERNEL = sbgemv_t_bfdot.c \ No newline at end of file diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1 index d14993544..bacedf8cf 100644 --- a/kernel/arm64/KERNEL.NEOVERSEV1 +++ b/kernel/arm64/KERNEL.NEOVERSEV1 @@ -17,4 +17,6 @@ 
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) SBGEMVNKERNEL = sbgemv_n_neon.c +SBGEMVTKERNEL = sbgemv_t_bfdot.c + endif \ No newline at end of file diff --git a/kernel/arm64/KERNEL.NEOVERSEV2 b/kernel/arm64/KERNEL.NEOVERSEV2 index bc5999097..4d866f858 100644 --- a/kernel/arm64/KERNEL.NEOVERSEV2 +++ b/kernel/arm64/KERNEL.NEOVERSEV2 @@ -1 +1,5 @@ include $(KERNELDIR)/KERNEL.ARMV8SVE + +ifeq ($(BUILD_BFLOAT16), 1) +SBGEMVTKERNEL = sbgemv_t_bfdot.c +endif \ No newline at end of file diff --git a/kernel/arm64/sbgemv_t_bfdot.c b/kernel/arm64/sbgemv_t_bfdot.c new file mode 100644 index 000000000..0751690fc --- /dev/null +++ b/kernel/arm64/sbgemv_t_bfdot.c @@ -0,0 +1,207 @@ +/*************************************************************************** +Copyright (c) 2025, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include <arm_neon.h> +#include "common.h" + +static inline float bf16_to_fp32(bfloat16 bf16) { + uint32_t fp32 = (uint32_t)bf16 << 16; + return *((float*)&fp32); +} + +int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) +{ + if (m < 1 || n < 1) return(0); + BLASLONG i; + BLASLONG ix,iy; + BLASLONG j; + bfloat16_t *a_ptr; + bfloat16_t *x_ptr; + float *y_ptr; + float temp; + + iy = 0; + a_ptr = (bfloat16_t*)(a); + x_ptr = (bfloat16_t*)(x); + + if (incx == 1) { + BLASLONG width = n / 4; + + bfloat16_t *a0_ptr = a_ptr + lda * width * 0; + bfloat16_t *a1_ptr = a_ptr + lda * width * 1; + bfloat16_t *a2_ptr = a_ptr + lda * width * 2; + bfloat16_t *a3_ptr = a_ptr + lda * width * 3; + + float *y0_ptr = y + incy * width * 0; + float *y1_ptr = y + incy * width * 1; + float *y2_ptr = y + incy * width * 2; + float *y3_ptr = y + incy * width * 3; + + for (j = 0; j < width; j++) { + float32x4_t temp0_vec = vdupq_n_f32(0.0f); + float32x4_t temp1_vec = vdupq_n_f32(0.0f); + float32x4_t temp2_vec = vdupq_n_f32(0.0f); + float32x4_t temp3_vec = vdupq_n_f32(0.0f); + + i = 0; + while (i + 7 < m) { + bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i); + + bfloat16x8_t a0_vec = vld1q_bf16(a0_ptr + i); + bfloat16x8_t a1_vec = vld1q_bf16(a1_ptr + i); + bfloat16x8_t a2_vec = vld1q_bf16(a2_ptr + i); + bfloat16x8_t a3_vec =
vld1q_bf16(a3_ptr + i); + + temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec); + temp1_vec = vbfdotq_f32(temp1_vec, a1_vec, x_vec); + temp2_vec = vbfdotq_f32(temp2_vec, a2_vec, x_vec); + temp3_vec = vbfdotq_f32(temp3_vec, a3_vec, x_vec); + + i += 8; + } + if (i + 3 < m) { + float32x2_t t0 = vdup_n_f32(0.0f); + float32x2_t t1 = vdup_n_f32(0.0f); + float32x2_t t2 = vdup_n_f32(0.0f); + float32x2_t t3 = vdup_n_f32(0.0f); + + bfloat16x4_t x_vec = vld1_bf16(x_ptr + i); + + bfloat16x4_t a0_vec = vld1_bf16(a0_ptr + i); + bfloat16x4_t a1_vec = vld1_bf16(a1_ptr + i); + bfloat16x4_t a2_vec = vld1_bf16(a2_ptr + i); + bfloat16x4_t a3_vec = vld1_bf16(a3_ptr + i); + + t0 = vbfdot_f32(t0, a0_vec, x_vec); + t1 = vbfdot_f32(t1, a1_vec, x_vec); + t2 = vbfdot_f32(t2, a2_vec, x_vec); + t3 = vbfdot_f32(t3, a3_vec, x_vec); + + float32x2_t temp0_vec_low = vget_low_f32(temp0_vec); + float32x2_t temp1_vec_low = vget_low_f32(temp1_vec); + float32x2_t temp2_vec_low = vget_low_f32(temp2_vec); + float32x2_t temp3_vec_low = vget_low_f32(temp3_vec); + + temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec)); + temp1_vec = vcombine_f32(vadd_f32(t1, temp1_vec_low), vget_high_f32(temp1_vec)); + temp2_vec = vcombine_f32(vadd_f32(t2, temp2_vec_low), vget_high_f32(temp2_vec)); + temp3_vec = vcombine_f32(vadd_f32(t3, temp3_vec_low), vget_high_f32(temp3_vec)); + + i += 4; + } + if (beta == 0.0f) { + y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec); + y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec); + y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec); + y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec); + } + else { + y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy]; + y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy]; + y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy]; + y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy]; + } + + for (; i < m; ++i) { + y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i]; + y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i]; + 
y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i]; + y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i]; + } + + iy += incy; + + a0_ptr += lda; + a1_ptr += lda; + a2_ptr += lda; + a3_ptr += lda; + } + + a_ptr = a3_ptr; + y_ptr = y3_ptr; + for (j = width * 4; j < n; j++) { + float32x4_t temp0_vec = vdupq_n_f32(0.0f); + i = 0; + while (i + 7 < m) { + bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i); + bfloat16x8_t a0_vec = vld1q_bf16(a_ptr + i); + temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec); + + i += 8; + } + if (i + 3 < m) { + float32x2_t t0 = vdup_n_f32(0.0f); + bfloat16x4_t x_vec = vld1_bf16(x_ptr + i); + bfloat16x4_t a0_vec = vld1_bf16(a_ptr + i); + + t0 = vbfdot_f32(t0, a0_vec, x_vec); + float32x2_t temp0_vec_low = vget_low_f32(temp0_vec); + temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec)); + + i += 4; + } + if (beta == 0.0f) { + y_ptr[iy] = alpha * vaddvq_f32(temp0_vec); + } + else { + y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy]; + } + + for (; i < m; ++i) { + y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i]; + } + + iy += incy; + + a_ptr += lda; + } + return(0); + } + + for (j = 0; j < n; j++) { + temp = 0.0; + ix = 0; + for (i = 0; i < m; i++) { + temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]); + ix += incx; + } + if (beta == 0.0f) { + y[iy] = alpha * temp; + } + else { + y[iy] = alpha * temp + beta * y[iy]; + } + iy += incy; + a += lda; + } + return (0); +} diff --git a/kernel/power/scal.S b/kernel/power/scal.S index eceb9fe8e..8fd175d18 100644 --- a/kernel/power/scal.S +++ b/kernel/power/scal.S @@ -51,7 +51,7 @@ #else #define X r7 #define INCX r8 -#define FLAG r12 +#define FLAG r11 #endif #endif @@ -63,7 +63,7 @@ #else #define X r7 #define INCX r8 -#define FLAG r12 +#define FLAG r11 #endif #endif @@ -91,7 +91,7 @@ fcmpu cr0, FZERO, ALPHA bne- cr0, LL(A1I1) - LDLONG FLAG, 48+64+8(SP) + LDLONG FLAG, 104(SP) cmpwi cr0, FLAG, 1 beq- cr0, LL(A1I1)