Forward GEMM to GEMV when one argument is actually a vector
tags/v0.3.28^2
| @@ -274,9 +274,18 @@ endif | |||||
| ifeq ($(ARCH), loongarch64) | ifeq ($(ARCH), loongarch64) | ||||
| SMALL_MATRIX_OPT = 1 | SMALL_MATRIX_OPT = 1 | ||||
| endif | endif | ||||
| ifeq ($(ARCH), arm64) | |||||
| GEMM_GEMV_FORWARD = 1 | |||||
| endif | |||||
| ifeq ($(SMALL_MATRIX_OPT), 1) | ifeq ($(SMALL_MATRIX_OPT), 1) | ||||
| CCOMMON_OPT += -DSMALL_MATRIX_OPT | CCOMMON_OPT += -DSMALL_MATRIX_OPT | ||||
| endif | endif | ||||
| ifeq ($(GEMM_GEMV_FORWARD), 1) | |||||
| ifneq ($(ONLY_CBLAS), 1) | |||||
| CCOMMON_OPT += -DGEMM_GEMV_FORWARD | |||||
| endif | |||||
| endif | |||||
| # This operation is expensive, so execution should be once. | # This operation is expensive, so execution should be once. | ||||
| ifndef GOTOBLAS_MAKEFILE | ifndef GOTOBLAS_MAKEFILE | ||||
| @@ -391,6 +391,13 @@ endif () | |||||
| if (X86_64 OR ${CORE} STREQUAL POWER10) | if (X86_64 OR ${CORE} STREQUAL POWER10) | ||||
| set(SMALL_MATRIX_OPT TRUE) | set(SMALL_MATRIX_OPT TRUE) | ||||
| endif () | endif () | ||||
| if (ARM64) | |||||
| set(GEMM_GEMV_FORWARD TRUE) | |||||
| endif () | |||||
| if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) | |||||
| set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") | |||||
| endif () | |||||
| if (SMALL_MATRIX_OPT) | if (SMALL_MATRIX_OPT) | ||||
| set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") | set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") | ||||
| endif () | endif () | ||||
| @@ -1,4 +1,5 @@ | |||||
| /*********************************************************************/ | /*********************************************************************/ | ||||
| /* Copyright 2024 The OpenBLAS Project */ | |||||
| /* Copyright 2009, 2010 The University of Texas at Austin. */ | /* Copyright 2009, 2010 The University of Texas at Austin. */ | ||||
| /* All rights reserved. */ | /* All rights reserved. */ | ||||
| /* */ | /* */ | ||||
| @@ -47,12 +48,16 @@ | |||||
| #define SMP_THRESHOLD_MIN 65536.0 | #define SMP_THRESHOLD_MIN 65536.0 | ||||
| #ifdef XDOUBLE | #ifdef XDOUBLE | ||||
| #define ERROR_NAME "QGEMM " | #define ERROR_NAME "QGEMM " | ||||
| #define GEMV BLASFUNC(qgemv) | |||||
| #elif defined(DOUBLE) | #elif defined(DOUBLE) | ||||
| #define ERROR_NAME "DGEMM " | #define ERROR_NAME "DGEMM " | ||||
| #define GEMV BLASFUNC(dgemv) | |||||
| #elif defined(BFLOAT16) | #elif defined(BFLOAT16) | ||||
| #define ERROR_NAME "SBGEMM " | #define ERROR_NAME "SBGEMM " | ||||
| #define GEMV BLASFUNC(sbgemv) | |||||
| #else | #else | ||||
| #define ERROR_NAME "SGEMM " | #define ERROR_NAME "SGEMM " | ||||
| #define GEMV BLASFUNC(sgemv) | |||||
| #endif | #endif | ||||
| #else | #else | ||||
| #define SMP_THRESHOLD_MIN 8192.0 | #define SMP_THRESHOLD_MIN 8192.0 | ||||
| @@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||||
| args.m, args.n, args.k, args.lda, args.ldb, args.ldc); | args.m, args.n, args.k, args.lda, args.ldb, args.ldc); | ||||
| #endif | #endif | ||||
| #if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) | |||||
| // Check if we can convert GEMM -> GEMV | |||||
| if (args.k != 0) { | |||||
| if (args.n == 1) { | |||||
| blasint inc_x = 1; | |||||
| blasint inc_y = 1; | |||||
| // These were passed in as blasint, but the struct translates them to blaslong | |||||
| blasint m = args.m; | |||||
| blasint n = args.k; | |||||
| blasint lda = args.lda; | |||||
| // Create new transpose parameters | |||||
| char NT = 'N'; | |||||
| if (transa & 1) { | |||||
| NT = 'T'; | |||||
| m = args.k; | |||||
| n = args.m; | |||||
| } | |||||
| if (transb & 1) { | |||||
| inc_x = args.ldb; | |||||
| } | |||||
| GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); | |||||
| return; | |||||
| } | |||||
| if (args.m == 1) { | |||||
| blasint inc_x = args.lda; | |||||
| blasint inc_y = args.ldc; | |||||
| // These were passed in as blasint, but the struct translates them to blaslong | |||||
| blasint m = args.k; | |||||
| blasint n = args.n; | |||||
| blasint ldb = args.ldb; | |||||
| // Create new transpose parameters | |||||
| char NT = 'T'; | |||||
| if (transa & 1) { | |||||
| inc_x = 1; | |||||
| } | |||||
| if (transb & 1) { | |||||
| NT = 'N'; | |||||
| m = args.n; | |||||
| n = args.k; | |||||
| } | |||||
| GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); | |||||
| return; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| IDEBUG_START; | IDEBUG_START; | ||||
| FUNCTION_PROFILE_START(); | FUNCTION_PROFILE_START(); | ||||
| @@ -1 +1,4 @@ | |||||
| include $(KERNELDIR)/KERNEL.ARMV8SVE | include $(KERNELDIR)/KERNEL.ARMV8SVE | ||||
| SGEMVTKERNEL = gemv_t_sve.c | |||||
| DGEMVTKERNEL = gemv_t_sve.c | |||||
| @@ -1,5 +1,5 @@ | |||||
| /******************************************************************************* | /******************************************************************************* | ||||
| Copyright (c) 2015, The OpenBLAS Project | |||||
| Copyright (c) 2015, 2024 The OpenBLAS Project | |||||
| All rights reserved. | All rights reserved. | ||||
| Redistribution and use in source and binary forms, with or without | Redistribution and use in source and binary forms, with or without | ||||
| modification, are permitted provided that the following conditions are | modification, are permitted provided that the following conditions are | ||||
| @@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| .macro KERNEL_F32_FINALIZE | .macro KERNEL_F32_FINALIZE | ||||
| #if !defined(DOUBLE) | #if !defined(DOUBLE) | ||||
| fadd v1.4s, v1.4s, v2.4s | |||||
| // F8 only has 2 accumulators | |||||
| // so add into those pairs | |||||
| fadd v1.4s, v1.4s, v3.4s | fadd v1.4s, v1.4s, v3.4s | ||||
| fadd v1.4s, v1.4s, v4.4s | |||||
| #else | |||||
| fadd v1.2d, v1.2d, v2.2d | |||||
| fadd v1.2d, v1.2d, v3.2d | |||||
| fadd v1.2d, v1.2d, v4.2d | |||||
| fadd v2.4s, v2.4s, v4.4s | |||||
| #endif | #endif | ||||
| .endm | .endm | ||||
| .macro KERNEL_F4 | |||||
| .macro KERNEL_F8 | |||||
| #if !defined(DOUBLE) | #if !defined(DOUBLE) | ||||
| ld1 {v2.4s}, [A_PTR], #16 | |||||
| ld1 {v3.4s}, [X_PTR], #16 | |||||
| fmla v1.4s, v2.4s, v3.4s | |||||
| #else | |||||
| ld1 {v2.2d}, [A_PTR], #16 | |||||
| ld1 {v3.2d}, [X_PTR], #16 | |||||
| fmla v1.2d, v2.2d, v3.2d | |||||
| ld1 {v4.2d}, [A_PTR], #16 | |||||
| ld1 {v5.2d}, [X_PTR], #16 | |||||
| fmla v1.2d, v4.2d, v5.2d | |||||
| ld1 {v13.4s, v14.4s}, [A_PTR], #32 | |||||
| ld1 {v17.4s, v18.4s}, [X_PTR], #32 | |||||
| fmla v1.4s, v13.4s, v17.4s | |||||
| fmla v2.4s, v14.4s, v18.4s | |||||
| #else | |||||
| ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64 | |||||
| ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64 | |||||
| fmla v1.2d, v13.2d, v17.2d | |||||
| fmla v2.2d, v14.2d, v18.2d | |||||
| fmla v3.2d, v15.2d, v19.2d | |||||
| fmla v4.2d, v16.2d, v20.2d | |||||
| #endif | #endif | ||||
| .endm | .endm | ||||
| .macro KERNEL_F4_FINALIZE | |||||
| .macro KERNEL_F8_FINALIZE | |||||
| #if !defined(DOUBLE) | #if !defined(DOUBLE) | ||||
| ext v2.16b, v1.16b, v1.16b, #8 | |||||
| // Take the top two elements of v1 and | |||||
| // put them into the first two lanes of v3 | |||||
| ext v3.16b, v1.16b, v1.16b, #8 | |||||
| fadd v1.2s, v1.2s, v3.2s | |||||
| ext v4.16b, v2.16b, v2.16b, #8 | |||||
| fadd v2.2s, v2.2s, v4.2s | |||||
| // Final pair | |||||
| fadd v1.2s, v1.2s, v2.2s | fadd v1.2s, v1.2s, v2.2s | ||||
| faddp TEMP, v1.2s | faddp TEMP, v1.2s | ||||
| #else | #else | ||||
| faddp TEMP, v1.2d | faddp TEMP, v1.2d | ||||
| faddp TEMP1, v2.2d | |||||
| faddp TEMP2, v3.2d | |||||
| faddp TEMP3, v4.2d | |||||
| fadd TEMP, TEMP, TEMP1 | |||||
| fadd TEMP2, TEMP2, TEMP3 | |||||
| fadd TEMP, TEMP, TEMP2 | |||||
| #endif | #endif | ||||
| .endm | .endm | ||||
| @@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| asr I, M, #5 | asr I, M, #5 | ||||
| cmp I, xzr | cmp I, xzr | ||||
| beq .Lgemv_t_kernel_F4 | |||||
| beq .Lgemv_t_kernel_F8 | |||||
| .Lgemv_t_kernel_F320: | .Lgemv_t_kernel_F320: | ||||
| @@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| KERNEL_F32_FINALIZE | KERNEL_F32_FINALIZE | ||||
| .Lgemv_t_kernel_F4: | |||||
| .Lgemv_t_kernel_F8: | |||||
| ands I, M, #31 | ands I, M, #31 | ||||
| asr I, I, #2 | |||||
| asr I, I, #3 | |||||
| cmp I, xzr | cmp I, xzr | ||||
| beq .Lgemv_t_kernel_F1 | beq .Lgemv_t_kernel_F1 | ||||
| .Lgemv_t_kernel_F40: | |||||
| .Lgemv_t_kernel_F80: | |||||
| KERNEL_F4 | |||||
| KERNEL_F8 | |||||
| subs I, I, #1 | subs I, I, #1 | ||||
| bne .Lgemv_t_kernel_F40 | |||||
| bne .Lgemv_t_kernel_F80 | |||||
| .Lgemv_t_kernel_F1: | .Lgemv_t_kernel_F1: | ||||
| KERNEL_F4_FINALIZE | |||||
| KERNEL_F8_FINALIZE | |||||
| ands I, M, #3 | |||||
| ands I, M, #7 | |||||
| ble .Lgemv_t_kernel_F_END | ble .Lgemv_t_kernel_F_END | ||||
| .Lgemv_t_kernel_F10: | .Lgemv_t_kernel_F10: | ||||
| @@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO | |||||
| a_ptr = a; | a_ptr = a; | ||||
| if (inc_x == 1) { | if (inc_x == 1) { | ||||
| svbool_t pg_true = SV_TRUE(); | |||||
| uint64_t sve_size = SV_COUNT(); | uint64_t sve_size = SV_COUNT(); | ||||
| uint64_t sve_size2 = sve_size * 2; | |||||
| BLASLONG m1 = m & -sve_size; | |||||
| BLASLONG m2 = m & -sve_size2; | |||||
| for (j = 0; j < n; j++) { | for (j = 0; j < n; j++) { | ||||
| BLASLONG i = 0; | |||||
| SV_TYPE temp_vec_v2_0 = SV_DUP(0.0); | |||||
| SV_TYPE temp_vec_v2_1 = SV_DUP(0.0); | |||||
| for (; i < m2; i += sve_size2) { | |||||
| SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); | |||||
| SV_TYPE x_vec0 = svld1(pg_true, x + i); | |||||
| SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size); | |||||
| SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size); | |||||
| temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0); | |||||
| temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1); | |||||
| } | |||||
| SV_TYPE temp_vec_v1 = SV_DUP(0.0); | |||||
| for (; i < m1; i += sve_size) { | |||||
| SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); | |||||
| SV_TYPE x_vec0 = svld1(pg_true, x + i); | |||||
| temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0); | |||||
| } | |||||
| SV_TYPE temp_vec = SV_DUP(0.0); | SV_TYPE temp_vec = SV_DUP(0.0); | ||||
| i = 0; | |||||
| svbool_t pg = SV_WHILE(i, m); | |||||
| while (svptest_any(SV_TRUE(), pg)) { | |||||
| for (; i < m; i += sve_size) { | |||||
| svbool_t pg = SV_WHILE(i, m); | |||||
| SV_TYPE a_vec = svld1(pg, a_ptr + i); | SV_TYPE a_vec = svld1(pg, a_ptr + i); | ||||
| SV_TYPE x_vec = svld1(pg, x + i); | SV_TYPE x_vec = svld1(pg, x + i); | ||||
| temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec); | temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec); | ||||
| i += sve_size; | |||||
| pg = SV_WHILE(i, m); | |||||
| } | } | ||||
| temp = svaddv(SV_TRUE(), temp_vec); | |||||
| y[iy] += alpha * temp; | |||||
| y[iy] += alpha * ( | |||||
| (svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) + | |||||
| (svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1)) | |||||
| ); | |||||
| iy += inc_y; | iy += inc_y; | ||||
| a_ptr += lda; | a_ptr += lda; | ||||
| } | } | ||||