diff --git a/.gitignore b/.gitignore index c0885d466..1807d4496 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,7 @@ test/ZBLAT2.SUMM test/ZBLAT3.SUMM test/ZBLAT3_3M.SUMM test/SHBLAT3.SUMM +test/SBBLAT2.SUMM test/SBBLAT3.SUMM test/BBLAT3.SUMM test/cblat1 @@ -97,6 +98,7 @@ test/sblat3 test/sblat3_3m test/test_shgemm test/test_sbgemm +test/test_sbgemv test/test_bgemm test/zblat1 test/zblat2 diff --git a/test/Makefile b/test/Makefile index cd8006c04..144738eb2 100644 --- a/test/Makefile +++ b/test/Makefile @@ -119,6 +119,9 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16), 1) +B2 = test_sbgemv +endif ifeq ($(BUILD_SINGLE),1) S2=sblat2 endif @@ -132,11 +135,15 @@ ifeq ($(BUILD_COMPLEX16),1) Z2=zblat2 endif -level2: $(S2) $(D2) $(C2) $(Z2) +level2: $(B2) $(S2) $(D2) $(C2) $(Z2) ifneq ($(CROSS), 1) rm -f ?BLAT2.SUMM +ifeq ($(BUILD_BFLOAT16),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM + @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -156,6 +163,10 @@ endif ifdef SMP rm -f ?BLAT2.SUMM ifeq ($(USE_OPENMP), 1) +ifeq ($(BUILD_BFLOAT16),1) + OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM + @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OMP_NUM_THREADS=2 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -173,6 +184,10 @@ ifeq ($(BUILD_COMPLEX16),1) @$(GREP) -q FATAL ZBLAT2.SUMM && cat ZBLAT2.SUMM || exit 0 endif else +ifeq ($(BUILD_BFLOAT16),1) + OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM + @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=2 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -195,7 +210,7 @@ endif ifeq ($(BUILD_BFLOAT16),1) BF3= test_bgemm -B3= test_sbgemm +B3 = test_sbgemm endif ifeq ($(BUILD_SINGLE),1) S3=sblat3 @@ -408,6 +423,9 @@ test_bgemm : compare_sgemm_bgemm.c test_helpers.h ../$(LIBNAME) test_sbgemm : compare_sgemm_sbgemm.c test_helpers.h ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + +test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif ifeq ($(BUILD_COMPLEX),1) @@ -426,7 +444,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c index 8ece63841..bc8a0b468 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -158,6 +158,7 @@ main (int argc, char *argv[]) if (ret != 0) { fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret); - return ret; } + + return ret; } diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index 4fa24b9ce..489222516 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -141,87 +141,7 @@ main (int argc, char *argv[]) if (ret != 0) { fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); - return ret; } - for (beta = 0; beta < 3; beta += 1) { - for (alpha = 0; alpha < 3; alpha += 1) { - for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. - for (x = 1; x <= loop; x++) - { - k = (x == 0) ? 0 : l + 1; - float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); - float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); - bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); - bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); - float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); - if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || - (DD == NULL) || (CC == NULL)) - return 1; - blasint one = 1; - - for (j = 0; j < x; j++) - { - for (i = 0; i < x; i++) - { - A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one); - } - B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one); - - CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - } - - for (y = 0; y < 2; y++) - { - if (y == 0) { - transA = 'N'; - } else { - transA = 'T'; - } - - memset(CC, 0, x * sizeof(FLOAT) << l); - memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * sizeof(FLOAT) << l); - - SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); - SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); - - for (int i = 0; i < x; i ++) DD[i] *= beta; - - for (j = 0; j < x; j++) - for (i = 0; i < x; i++) - if (transA == 'N') { - DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); - } else if (transA == 'T') { - DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); - } - - for (j = 0; j < x; j++) { - if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { - ret++; - } - if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { - ret++; - } - } - } - free(A); - free(B); - free(C); - free(AA); - free(BB); - free(DD); - free(CC); - } // x - } // l - } // alpha - } // beta - - if (ret != 0) - fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); return ret; } diff --git a/test/compare_sgemv_sbgemv.c b/test/compare_sgemv_sbgemv.c new file mode 100644 index 000000000..5fa2d5f66 --- /dev/null +++ b/test/compare_sgemv_sbgemv.c @@ -0,0 +1,128 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include +#include +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMV BLASFUNC(sgemv) +#define SBGEMV BLASFUNC(sbgemv) +#define SBGEMV_LARGEST 256 + +int +main (int argc, char *argv[]) +{ + blasint k; + int i, j, l; + blasint x, y; + int ret = 0; + int loop = SBGEMV_LARGEST; + char transA = 'N'; + float alpha = 1.0, beta = 0.0; + + for (beta = 0; beta < 3; beta += 1) { + for (alpha = 0; alpha < 3; alpha += 1) { + for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. + for (x = 1; x <= loop; x++) + { + k = (x == 0) ? 0 : l + 1; + float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); + bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); + bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); + float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + blasint one = 1; + + for (j = 0; j < x; j++) + { + for (i = 0; i < x; i++) + { + A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one); + } + B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one); + + CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + } + + for (y = 0; y < 2; y++) + { + if (y == 0) { + transA = 'N'; + } else { + transA = 'T'; + } + + memset(CC, 0, x * sizeof(FLOAT) << l); + memset(DD, 0, x * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT) << l); + + SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); + SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); + + for (int i = 0; i < x; i ++) DD[i] *= beta; + + for (j = 0; j < x; j++) + for (i = 0; i < x; i++) + if (transA == 'N') { + DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); + } else if (transA == 'T') { + DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); + } + + for (j = 0; j < x; j++) { + if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { + ret++; + } + if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } // x + } // l + } // alpha + } // beta + + if (ret != 0) + fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); + return ret; +}