From abe9d38f742432ede3b0f40239effc381726860b Mon Sep 17 00:00:00 2001 From: Ye Tao Date: Wed, 21 May 2025 14:52:56 +0000 Subject: [PATCH] add generic bgemm kernel and its test file --- kernel/arm64/bgemm_beta.c | 38 ++++++ kernel/arm64/bgemm_kernel.c | 37 +++++ kernel/generic/bgemmkernel_2x2.c | 227 +++++++++++++++++++++++++++++++ kernel/generic/gemm_beta.c | 34 +++++ test/compare_sgemm_bgemm.c | 138 +++++++++++++++++++ 5 files changed, 474 insertions(+) create mode 100644 kernel/arm64/bgemm_beta.c create mode 100644 kernel/arm64/bgemm_kernel.c create mode 100644 kernel/generic/bgemmkernel_2x2.c create mode 100644 test/compare_sgemm_bgemm.c diff --git a/kernel/arm64/bgemm_beta.c b/kernel/arm64/bgemm_beta.c new file mode 100644 index 000000000..40c17d780 --- /dev/null +++ b/kernel/arm64/bgemm_beta.c @@ -0,0 +1,38 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ * *****************************************************************************/
+
+#include "common.h"
+
+#include <stdio.h>
+
+/*
+ * Placeholder for the arm64 BGEMM beta kernel.
+ *
+ * Prints a trace line to stdout and returns 0. None of the arguments are
+ * read or written, so c is NOT scaled by beta here.
+ * NOTE(review): presumably to be replaced by a real implementation before
+ * this kernel is selected by a build -- confirm.
+ */
+int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, IFLOAT *dummy2,
+          BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c,
+          BLASLONG ldc) {
+  printf("running bgemm_beta...\n");
+  return 0;
+}
diff --git a/kernel/arm64/bgemm_kernel.c b/kernel/arm64/bgemm_kernel.c
new file mode 100644
index 000000000..223cdc39c
--- /dev/null
+++ b/kernel/arm64/bgemm_kernel.c
@@ -0,0 +1,37 @@
+/***************************************************************************
+ * Copyright (c) 2025, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ * 3. Neither the name of the OpenBLAS project nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ * *****************************************************************************/
+
+#include <stdio.h>
+#include "common.h"
+
+/*
+ * Placeholder for the arm64 BGEMM compute kernel.
+ *
+ * Prints a trace line to stdout and returns 0. A, B and C are not touched,
+ * so no matrix product is computed here.
+ * NOTE(review): presumably to be replaced by a real implementation before
+ * this kernel is selected by a build -- confirm.
+ */
+int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 alpha, IFLOAT *A, IFLOAT *B,
+          FLOAT *C, BLASLONG ldc) {
+  printf("running bgemm_kernel...\n");
+  return 0;
+}
+
diff --git a/kernel/generic/bgemmkernel_2x2.c b/kernel/generic/bgemmkernel_2x2.c
new file mode 100644
index 000000000..5fe8d3255
--- /dev/null
+++ b/kernel/generic/bgemmkernel_2x2.c
@@ -0,0 +1,227 @@
+/***************************************************************************
+ * Copyright (c) 2025, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ * 3.
Neither the name of the OpenBLAS project nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ * *****************************************************************************/
+
+#include "common.h"
+
+/*
+ * Widen a bfloat16 bit pattern to float: the 16 bits become the most
+ * significant half of the float and the other half is zero.
+ * A union is used so the float is never accessed through an
+ * unsigned short pointer. Assumes a 32-bit float and 16-bit unsigned short.
+ */
+static float bfloat16tof32(bfloat16 f16) {
+  union {
+    float f;
+    unsigned short h[2];
+  } bits;
+  bits.f = 0;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  bits.h[0] = f16;
+#else
+  bits.h[1] = f16;
+#endif
+  return bits.f;
+}
+
+/*
+ * Narrow a float to bfloat16 by keeping its most significant 16 bits.
+ * The discarded half is dropped without rounding (truncation).
+ */
+static bfloat16 f32tobfloat16(float f32) {
+  union {
+    float f;
+    unsigned short h[2];
+  } bits;
+  bits.f = f32;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  return bits.h[0];
+#else
+  return bits.h[1];
+#endif
+}
+
+#define BF16TOF32(x) (bfloat16tof32(x))
+#define F32TOBF16(x) (f32tobfloat16(x))
+
+/*
+ * Generic 2x2-blocked BGEMM micro-kernel: C += alpha * A * B.
+ *
+ * bm, bn   number of rows / columns of C that are updated.
+ * bk       length of the inner (dot-product) dimension.
+ * alpha    scale factor as a bfloat16 bit pattern; widened to float once.
+ * ba       A operand. For each pair of rows it is read as two consecutive
+ *          values per k step; a trailing odd row is read as one value per
+ *          k step. The same ba data is re-read for every column group.
+ * bb       B operand. For each pair of columns it is read as two
+ *          consecutive values per k step; a trailing odd column as one.
+ *          NOTE(review): presumably both operands come from the packing
+ *          (copy) routines -- confirm against the caller.
+ * C, ldc   output, read and written as bfloat16; consecutive columns are
+ *          ldc elements apart.
+ *
+ * Every dot product is accumulated in float; the existing C value is
+ * widened, added to alpha * sum, and the result is narrowed with F32TOBF16.
+ * Always returns 0.
+ */
+int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba,
+          IFLOAT *bb, FLOAT *C, BLASLONG ldc) {
+  BLASLONG i, j, k;
+  FLOAT *C0, *C1; // bfloat16
+  IFLOAT *ptrba, *ptrbb;
+  float res0, res1, res2, res3;
+  float a0, a1, b0, b1;
+
+  float alpha_ = BF16TOF32(alpha);
+
+  /* Column pairs. */
+  for (j = 0; j < bn / 2; j += 1) {
+    C0 = C;
+    C1 = C0 + ldc;
+    ptrba = ba;
+
+    /* Row pairs: one 2x2 tile of C per iteration. */
+    for (i = 0; i < bm / 2; i += 1) {
+      ptrbb = bb;
+      res0 = 0;
+      res1 = 0;
+      res2 = 0;
+      res3 = 0;
+
+      /* Both operand pointers must advance on every k step; ptrba is left
+         pointing at the next row group when the loop ends. */
+      for (k = 0; k < bk; k += 1) {
+        a0 = BF16TOF32(ptrba[0]);
+        a1 = BF16TOF32(ptrba[1]);
+        b0 = BF16TOF32(ptrbb[0]);
+        b1 = BF16TOF32(ptrbb[1]);
+
+        res0 = res0 + a0 * b0;
+        res1 = res1 + a1 * b0;
+        res2 = res2 + a0 * b1;
+        res3 = res3 + a1 * b1;
+
+        ptrba = ptrba + 2;
+        ptrbb = ptrbb + 2;
+      }
+
+      res0 = res0 * alpha_ + BF16TOF32(C0[0]);
+      res1 = res1 * alpha_ + BF16TOF32(C0[1]);
+      res2 = res2 * alpha_ + BF16TOF32(C1[0]);
+      res3 = res3 * alpha_ + BF16TOF32(C1[1]);
+
+      C0[0] = F32TOBF16(res0);
+      C0[1] = F32TOBF16(res1);
+      C1[0] = F32TOBF16(res2);
+      C1[1] = F32TOBF16(res3);
+
+      C0 = C0 + 2;
+      C1 = C1 + 2;
+    }
+
+    /* Trailing odd row: a 1x2 tile. */
+    for (i = 0; i < (bm & 1); i += 1) {
+      ptrbb = bb;
+      res0 = 0;
+      res1 = 0;
+
+      for (k = 0; k < bk; k += 1) {
+        a0 = BF16TOF32(ptrba[0]);
+        b0 = BF16TOF32(ptrbb[0]);
+        b1 = BF16TOF32(ptrbb[1]);
+
+        res0 = res0 + a0 * b0;
+        res1 = res1 + a0 * b1;
+
+        ptrba = ptrba + 1;
+        ptrbb = ptrbb + 2;
+      }
+
+      res0 = res0 * alpha_ + BF16TOF32(C0[0]);
+      res1 = res1 * alpha_ + BF16TOF32(C1[0]);
+
+      /* Store through F32TOBF16 like every other tile; a plain assignment
+         would convert the numeric value instead of the bit pattern. */
+      C0[0] = F32TOBF16(res0);
+      C1[0] = F32TOBF16(res1);
+
+      C0 = C0 + 1;
+      C1 = C1 + 1;
+    }
+
+    /* Step past the two columns just consumed. */
+    bb = bb + (bk << 1);
+    C = C + (ldc << 1);
+  }
+
+  /* Trailing odd column. */
+  for (j = 0; j < (bn & 1); j += 1) {
+    C0 = C;
+    ptrba = ba;
+
+    /* Row pairs: a 2x1 tile. */
+    for (i = 0; i < bm / 2; i += 1) {
+      ptrbb = bb;
+      res0 = 0;
+      res1 = 0;
+
+      for (k = 0; k < bk; k += 1) {
+        a0 = BF16TOF32(ptrba[0]);
+        a1 = BF16TOF32(ptrba[1]);
+        b0 = BF16TOF32(ptrbb[0]);
+
+        res0 = res0 + a0 * b0;
+        res1 = res1 + a1 * b0;
+
+        ptrba = ptrba + 2;
+        ptrbb = ptrbb + 1;
+      }
+
+      res0 = res0 * alpha_ + BF16TOF32(C0[0]);
+      res1 = res1 * alpha_ + BF16TOF32(C0[1]);
+
+      C0[0] = F32TOBF16(res0);
+      C0[1] = F32TOBF16(res1);
+      C0 = C0 + 2;
+    }
+
+    /* Trailing odd row: a single element. */
+    for (i = 0; i < (bm & 1); i += 1) {
+      ptrbb = bb;
+      res0 = 0;
+
+      for (k = 0; k < bk; k += 1) {
+        a0 = BF16TOF32(ptrba[0]);
+        b0 = BF16TOF32(ptrbb[0]);
+        res0 += a0 * b0;
+        ptrba = ptrba + 1;
+        ptrbb = ptrbb + 1;
+      }
+
+      res0 = res0 * alpha_ + BF16TOF32(C0[0]);
+      C0[0] = F32TOBF16(res0);
+      C0 = C0 + 1;
+    }
+
+    bb = bb + bk;
+    C = C + ldc;
+  }
+
+  return 0;
+}
diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c
index ccb772cc7..64ff505e8 100644
--- a/kernel/generic/gemm_beta.c
+++ b/kernel/generic/gemm_beta.c
@@ -1,5 +1,6 @@
 /*********************************************************************/
 /* Copyright 2009, 2010 The University of Texas at Austin. */
+/* Copyright 2025 The OpenBLAS Project. */
 /* All rights reserved.
*/
 /* */
 /* Redistribution and use in source and binary forms, with or */
@@ -38,6 +39,46 @@
 
 #include "common.h"
 
+#if (defined(BFLOAT16) || defined(BFLOAT16_ONLY)) && defined(BFLOAT16CONVERSION)
+/* Widen a bfloat16 bit pattern to float: the 16 bits become the most
+   significant half of the float and the other half is zero. */
+static float
+bfloat16tof32 (bfloat16 f16)
+{
+  /* A union keeps the float from being accessed through an unsigned short pointer. */
+  union { float f; unsigned short h[2]; } bits;
+  bits.f = 0;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  bits.h[0] = f16;
+#else
+  bits.h[1] = f16;
+#endif
+  return bits.f;
+}
+
+/* Narrow a float to bfloat16 by keeping its most significant 16 bits;
+   the other half is discarded without rounding (truncation). */
+static bfloat16
+f32tobfloat16(float f32)
+{
+  union { float f; unsigned short h[2]; } bits;
+  bits.f = f32;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  return bits.h[0];
+#else
+  return bits.h[1];
+#endif
+}
+
+#define BF16TOF32(x) (bfloat16tof32(x))
+#define F32TOBF16(x) (f32tobfloat16(x))
+#else
+/* Identity mappings when the conversion helpers are not compiled in. */
+#define BF16TOF32(x) (x)
+#define F32TOBF16(x) (x)
+#endif
+
+
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
 	  IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5,
 	  FLOAT *c, BLASLONG ldc){
diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c
new file mode 100644
index 000000000..e5a2ba46d
--- /dev/null
+++ b/test/compare_sgemm_bgemm.c
@@ -0,0 +1,138 @@
+/***************************************************************************
+Copyright (c) 2025 The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include "../common.h"
+#include <math.h>
+#include <stdint.h>
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define SGEMM BLASFUNC(sgemm)
+#define BGEMM BLASFUNC(bgemm)
+#define BGEMM_LARGEST 256
+/* Allowed |bgemm - sgemm| as a fraction of max(1, |sgemm|). bfloat16 keeps
+   8 significant bits, so narrowing one result can already be off by up to
+   1/128 of its value; an absolute bound cannot hold for large entries.
+   NOTE(review): 1/64 is a guess at reasonable headroom -- confirm. */
+#define BGEMM_REL_TOL (1.0f / 64.0f)
+
+/* malloc() that never asks for zero bytes, so a NULL result always means
+   the allocation failed. */
+void *malloc_safe(size_t size) {
+  return malloc(size == 0 ? 1 : size);
+}
+
+/* Convert a float to a bfloat16 bit pattern with round-to-nearest-even,
+   using plain integer arithmetic so no target-specific bf16 type is needed.
+   NaN inputs yield a quiet NaN. */
+bfloat16 convert_to_bf16(float x) {
+  uint32_t bits;
+  memcpy(&bits, &x, sizeof(bits));
+  if ((bits & 0x7fffffffu) > 0x7f800000u)
+    return (bfloat16)((bits >> 16) | 0x0040u);
+  bits += 0x7fffu + ((bits >> 16) & 1u);
+  return (bfloat16)(bits >> 16);
+}
+
+/* Widen a bfloat16 bit pattern to float (exact). */
+static float convert_from_bf16(bfloat16 x) {
+  uint32_t bits = (uint32_t)x << 16;
+  float out;
+  memcpy(&out, &bits, sizeof(out));
+  return out;
+}
+
+/*
+ * Compare BGEMM against SGEMM on square problems.
+ *
+ * Sizes 1..100 and BGEMM_LARGEST are run with m = n = k. A and B are filled
+ * with the integers (index + 1) % 100, stored both as float and as bfloat16
+ * (all of these values are exactly representable in bfloat16). Every element
+ * of the bfloat16 result is compared with the float result under
+ * BGEMM_REL_TOL.
+ *
+ * Returns 0 when every element matches and 1 otherwise (including
+ * allocation failure); the mismatch count is printed to stderr.
+ */
+int main(int argc, char *argv[]) {
+  blasint m, n, k;
+  blasint x, y;
+  size_t idx;
+  int ret = 0;
+  char transA = 'N', transB = 'N';
+
+  float alpha = 1.0f, beta = 0.0f;
+  bfloat16 alpha_bf16 = convert_to_bf16(alpha),
+           beta_bf16 = convert_to_bf16(beta);
+
+  (void)argc;
+  (void)argv;
+
+  for (x = 1; x <= BGEMM_LARGEST; x++) {
+    if ((x > 100) && (x != BGEMM_LARGEST))
+      continue;
+    m = k = n = x;
+
+    size_t a_count = (size_t)m * (size_t)k;
+    size_t b_count = (size_t)k * (size_t)n;
+    size_t c_count = (size_t)m * (size_t)n;
+
+    float *A = malloc_safe(a_count * sizeof(*A));
+    float *B = malloc_safe(b_count * sizeof(*B));
+    float *C = malloc_safe(c_count * sizeof(*C));
+
+    bfloat16 *AA = malloc_safe(a_count * sizeof(*AA));
+    bfloat16 *BB = malloc_safe(b_count * sizeof(*BB));
+    bfloat16 *CC = malloc_safe(c_count * sizeof(*CC));
+
+    if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) ||
+        (BB == NULL) || (CC == NULL)) {
+      free(A);
+      free(B);
+      free(C);
+      free(AA);
+      free(BB);
+      free(CC);
+      return 1;
+    }
+
+    for (idx = 0; idx < a_count; idx++) {
+      A[idx] = (float)((idx + 1) % 100);
+      AA[idx] = convert_to_bf16(A[idx]);
+    }
+
+    for (idx = 0; idx < b_count; idx++) {
+      B[idx] = (float)((idx + 1) % 100);
+      BB[idx] = convert_to_bf16(B[idx]);
+    }
+
+    /* Only the first case (no transposes) is run for now. */
+    for (y = 0; y < 1; y++) {
+      if ((y == 0) || (y == 2)) {
+        transA = 'N';
+      } else {
+        transA = 'T';
+      }
+      if ((y == 0) || (y == 1)) {
+        transB = 'N';
+      } else {
+        transB = 'T';
+      }
+
+      memset(C, 0, c_count * sizeof(*C));
+      memset(CC, 0, c_count * sizeof(*CC));
+      /* SGEMM takes float scalars, BGEMM takes bfloat16 scalars. */
+      SGEMM(&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, C,
+            &m);
+      BGEMM(&transA, &transB, &m, &n, &k, &alpha_bf16, AA, &m, BB, &k,
+            &beta_bf16, CC, &m);
+
+      for (idx = 0; idx < c_count; idx++) {
+        float ref = C[idx];
+        float got = convert_from_bf16(CC[idx]);
+        float tol = BGEMM_REL_TOL * fmaxf(1.0f, fabsf(ref));
+        /* Written as !(<=) so that a NaN result counts as a mismatch. */
+        if (!(fabsf(got - ref) <= tol)) {
+          ret++;
+        }
+      }
+    }
+
+    free(A);
+    free(B);
+    free(C);
+    free(AA);
+    free(BB);
+    free(CC);
+  }
+
+  if (ret != 0) {
+    fprintf(stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
+    /* Do not return the raw count: exit statuses are truncated, so a
+       multiple of 256 would look like success. */
+    return 1;
+  }
+
+  return 0;
+}