/*************************************************************************** * Copyright (c) 2025, The OpenBLAS Project * All rights reserved. * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in * the documentation and/or other materials provided with the * distribution. * 3. Neither the name of the OpenBLAS project nor the names of * its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ #include "common.h" static float bfloat16tof32(bfloat16 f16) { float result = 0; unsigned short *q = (unsigned short *)(&result); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ q[0] = f16; #else q[1] = f16; #endif return result; } static bfloat16 f32tobfloat16(float f32) { unsigned short *q = (unsigned short *)(&f32); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ return q[0]; #else return q[1]; #endif } #define BF16TOF32(x) (bfloat16tof32(x)) #define F32TOBF16(x) (f32tobfloat16(x)) int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba, IFLOAT *bb, FLOAT *C, BLASLONG ldc) { BLASLONG i, j, k; FLOAT *C0, *C1; // bfloat16 IFLOAT *ptrba, *ptrbb; float res0, res1, res2, res3; float load0, load1, load2, load3, load4, load5, load6, load7; float alpha_ = BF16TOF32(alpha); for (j = 0; j < bn / 2; j += 1) { C0 = C; C1 = C0 + ldc; ptrba = ba; for (i = 0; i < bm / 2; i += 1) { ptrbb = bb; res0 = 0; res1 = 0; res2 = 0; res3 = 0; for (k = 0; k < bk / 4; k += 1) { load0 = BF16TOF32(ptrba[2 * 0 + 0]); load2 = BF16TOF32(ptrba[2 * 0 + 1]); load4 = BF16TOF32(ptrba[2 * 1 + 0]); load6 = BF16TOF32(ptrba[2 * 1 + 1]); load1 = BF16TOF32(ptrbb[2 * 0 + 0]); load3 = BF16TOF32(ptrbb[2 * 0 + 1]); load5 = BF16TOF32(ptrbb[2 * 1 + 0]); load7 = BF16TOF32(ptrbb[2 * 1 + 1]); res0 = res0 + load0 * load1; res1 = res1 + load2 * load1; res2 = res2 + load0 * load3; res3 = res3 + load2 * load3; res0 = res0 + load4 * load5; res1 = res1 + load6 * load5; res2 = res2 + load4 * load7; res3 = res3 + load6 * load7; load0 = BF16TOF32(ptrba[2 * 2 + 0]); load2 = BF16TOF32(ptrba[2 * 2 + 1]); load4 = BF16TOF32(ptrba[2 * 3 + 0]); load6 = BF16TOF32(ptrba[2 * 3 + 1]); load1 = BF16TOF32(ptrbb[2 * 2 + 0]); load3 = BF16TOF32(ptrbb[2 * 2 + 1]); load5 = BF16TOF32(ptrbb[2 * 3 + 0]); load7 = BF16TOF32(ptrbb[2 * 3 + 1]); res0 = res0 + load0 * load1; res1 = res1 + load2 * load1; res2 = res2 + load0 * load3; res3 = res3 + load2 * load3; res0 = res0 + load4 * load5; res1 = res1 + load6 * load5; res2 = res2 + load4 * load7; res3 = res3 + load6 * load7; } for (k = 0; k < (bk & 3); k += 1) { load0 = BF16TOF32(ptrba[2 * 0 + 0]); load2 = BF16TOF32(ptrba[2 * 0 + 1]); load1 = BF16TOF32(ptrbb[2 * 0 + 0]); load3 = BF16TOF32(ptrbb[2 * 0 + 1]); res0 = res0 + load0 * load1; res1 = res1 + load2 * load1; res2 = res2 + load0 * load3; res3 = res3 + load2 * load3; ptrba = ptrba + 2; ptrbb = ptrbb + 2; } res0 = res0 * alpha_ + BF16TOF32(C0[0]); res1 = res1 * alpha_ + BF16TOF32(C0[1]); res2 = res2 * alpha_ + BF16TOF32(C1[0]); res3 = res3 * alpha_ + BF16TOF32(C1[1]); C0[0] = F32TOBF16(res0); C0[1] = F32TOBF16(res1); C1[0] = F32TOBF16(res2); C1[1] = F32TOBF16(res3); C0 = C0 + 2; C1 = C1 + 2; } for (i = 0; i < (bm & 1); i += 1) { ptrbb = bb; res0 = 0; res1 = 0; for (k = 0; k < bk; k += 1) { load0 = BF16TOF32(ptrba[0 + 0]); load1 = BF16TOF32(ptrbb[2 * 0 + 0]); load2 = BF16TOF32(ptrbb[2 * 0 + 1]); res0 = res0 + load0 * load1; res1 = res1 + load0 * load2; ptrba = ptrba + 1; ptrbb = ptrbb + 2; } res0 = res0 * alpha_ + BF16TOF32(C0[0]); res1 = res1 * alpha_ + BF16TOF32(C1[0]); C0[0] = res0; C1[0] = res1; C0 = C0 + 1; C1 = C1 + 1; } k = (bk << 1); bb = bb + k; i = (ldc << 1); C = C + i; } for (j = 0; j < (bn & 1); j += 1) { C0 = C; ptrba = ba; for (i = 0; i < bm / 2; i += 1) { ptrbb = bb; res0 = 0; res1 = 0; for (k = 0; k < bk; k += 1) { load0 = BF16TOF32(ptrba[2 * 0 + 0]); load2 = BF16TOF32(ptrba[2 * 0 + 1]); load1 = BF16TOF32(ptrbb[0 + 0]); res0 = res0 + load0 * load1; res1 = res1 + load2 * load1; ptrba = ptrba + 2; ptrbb = ptrbb + 1; } res0 = res0 * alpha_ + BF16TOF32(C0[0]); res1 = res1 * alpha_ + BF16TOF32(C0[1]); C0[0] = F32TOBF16(res0); C0[1] = F32TOBF16(res1); C0 = C0 + 2; } for (i = 0; i < (bm & 1); i += 1) { ptrbb = bb; res0 = 0; for (k = 0; k < bk; k += 1) { load0 = BF16TOF32(ptrba[0 + 0]); load1 = BF16TOF32(ptrbb[0 + 0]); res0 += load0 * load1; ptrba = ptrba + 1; ptrbb = ptrbb + 1; } res0 = res0 * alpha_ + BF16TOF32(C0[0]); C0[0] = F32TOBF16(res0); C0 = C0 + 1; } k = (bk << 0); bb = bb + k; C = C + ldc; } return 0; }