Browse Source

Merge pull request #4488 from martin-frbg/issue4475-2

Separate the interface for SBGEMMT from GEMMT
tags/v0.3.27
Martin Kroeker GitHub 2 years ago
parent
commit
22b487b622
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
5 changed files with 466 additions and 11 deletions
  1. +1
    -0
      interface/CMakeLists.txt
  2. +2
    -6
      interface/Makefile
  3. +2
    -1
      interface/gemmt.c
  4. +447
    -0
      interface/sbgemmt.c
  5. +14
    -4
      test/compare_sgemm_sbgemm.c

+ 1
- 0
interface/CMakeLists.txt View File

@@ -119,6 +119,7 @@ endif ()
if (BUILD_BFLOAT16)
GenerateNamedObjects("bf16dot.c" "" "sbdot" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("gemm.c" "" "sbgemm" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("gemmt.c" "" "sbgemmt" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("sbgemv.c" "" "sbgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("tobf16.c" "SINGLE_PREC" "sbstobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")


+ 2
- 6
interface/Makefile View File

@@ -1303,7 +1303,7 @@ xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
ifeq ($(BUILD_BFLOAT16),1)
sbgemm.$(SUFFIX) sbgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
endif

@@ -1662,10 +1662,6 @@ cblas_zaxpyc.$(SUFFIX) cblas_zaxpyc.$(PSUFFIX) : zaxpy.c
cblas_xaxpyc.$(SUFFIX) cblas_xaxpyc.$(PSUFFIX) : zaxpy.c
$(CC) $(CFLAGS) -DCBLAS -c -DCONJ $< -o $(@F)

sscal.$(SUFFIX) sscal.$(PSUFFIX) : scal.c
$(CC) $(CFLAGS) -c $< -o $(@F)

dscal.$(SUFFIX) dscal.$(PSUFFIX) : scal.c
cblas_zaxpy.$(SUFFIX) cblas_zaxpy.$(PSUFFIX) : zaxpy.c
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)

@@ -1971,7 +1967,7 @@ cblas_sgemmt.$(SUFFIX) cblas_sgemmt.$(PSUFFIX) : gemmt.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)

ifeq ($(BUILD_BFLOAT16),1)
cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif



+ 2
- 1
interface/gemmt.c View File

@@ -158,7 +158,8 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
uplo = 0;
if (Uplo == 'L')
uplo = 1;

nrowa = m;
if (transa & 1) nrowa = k;
nrowb = k;
#if defined(COMPLEX)


+ 447
- 0
interface/sbgemmt.c View File

@@ -0,0 +1,447 @@
/*********************************************************************/
/* Copyright 2024, The OpenBLAS Project. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
/* without modification, are permitted provided that the following */
/* conditions are met: */
/* */
/* 1. Redistributions of source code must retain the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer. */
/* */
/* 2. Redistributions in binary form must reproduce the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer in the documentation and/or other materials */
/* provided with the distribution. */
/* */
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
/* POSSIBILITY OF SUCH DAMAGE. */
/* */
/*********************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "common.h"

#define SMP_THRESHOLD_MIN 65536.0
#define ERROR_NAME "SBGEMMT "

#ifndef GEMM_MULTITHREAD_THRESHOLD
#define GEMM_MULTITHREAD_THRESHOLD 4
#endif

#ifndef CBLAS

void NAME(char *UPLO, char *TRANSA, char *TRANSB,
blasint * M, blasint * K,
FLOAT * Alpha,
IFLOAT * a, blasint * ldA,
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
{

blasint m, k;
blasint lda, ldb, ldc;
int transa, transb, uplo;
blasint info;

char transA, transB, Uplo;
blasint nrowa, nrowb;
IFLOAT *buffer;
IFLOAT *aa, *bb;
FLOAT *cc;
FLOAT alpha, beta;

PRINT_DEBUG_NAME;

m = *M;
k = *K;

alpha = *Alpha;
beta = *Beta;

lda = *ldA;
ldb = *ldB;
ldc = *ldC;

transA = *TRANSA;
transB = *TRANSB;
Uplo = *UPLO;
TOUPPER(transA);
TOUPPER(transB);
TOUPPER(Uplo);

transa = -1;
transb = -1;
uplo = -1;

if (transA == 'N')
transa = 0;
if (transA == 'T')
transa = 1;

if (transA == 'R')
transa = 0;
if (transA == 'C')
transa = 1;

if (transB == 'N')
transb = 0;
if (transB == 'T')
transb = 1;

if (transB == 'R')
transb = 0;
if (transB == 'C')
transb = 1;

if (Uplo == 'U')
uplo = 0;
if (Uplo == 'L')
uplo = 1;
nrowa = m;
if (transa & 1) nrowa = k;
nrowb = k;
if (transb & 1) nrowb = m;

info = 0;

if (ldc < MAX(1, m))
info = 13;
if (ldb < MAX(1, nrowb))
info = 10;
if (lda < MAX(1, nrowa))
info = 8;
if (k < 0)
info = 5;
if (m < 0)
info = 4;
if (transb < 0)
info = 3;
if (transa < 0)
info = 2;
if (uplo < 0)
info = 1;

if (info != 0) {
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
return;
}
#else

void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m,
blasint k,
FLOAT alpha,
IFLOAT * A, blasint LDA,
IFLOAT * B, blasint LDB, FLOAT beta, FLOAT * c, blasint ldc)
{
IFLOAT *aa, *bb;
FLOAT *cc;

int transa, transb, uplo;
blasint info;
blasint lda, ldb;
IFLOAT *a, *b;
XFLOAT *buffer;

PRINT_DEBUG_CNAME;

uplo = -1;
transa = -1;
transb = -1;
info = 0;

if (order == CblasColMajor) {
if (Uplo == CblasUpper) uplo = 0;
if (Uplo == CblasLower) uplo = 1;

if (TransA == CblasNoTrans)
transa = 0;
if (TransA == CblasTrans)
transa = 1;

if (TransA == CblasConjNoTrans)
transa = 0;
if (TransA == CblasConjTrans)
transa = 1;

if (TransB == CblasNoTrans)
transb = 0;
if (TransB == CblasTrans)
transb = 1;

if (TransB == CblasConjNoTrans)
transb = 0;
if (TransB == CblasConjTrans)
transb = 1;

a = (void *)A;
b = (void *)B;
lda = LDA;
ldb = LDB;

info = -1;

blasint nrowa;
blasint nrowb;
nrowa = m;
if (transa & 1) nrowa = k;
nrowb = k;
if (transb & 1) nrowb = m;

if (ldc < MAX(1, m))
info = 13;
if (ldb < MAX(1, nrowb))
info = 10;
if (lda < MAX(1, nrowa))
info = 8;
if (k < 0)
info = 5;
if (m < 0)
info = 4;
if (transb < 0)
info = 3;
if (transa < 0)
info = 2;
if (uplo < 0)
info = 1;
}

if (order == CblasRowMajor) {

a = (void *)B;
b = (void *)A;

lda = LDB;
ldb = LDA;

if (Uplo == CblasUpper) uplo = 0;
if (Uplo == CblasLower) uplo = 1;

if (TransB == CblasNoTrans)
transa = 0;
if (TransB == CblasTrans)
transa = 1;

if (TransB == CblasConjNoTrans)
transa = 0;
if (TransB == CblasConjTrans)
transa = 1;

if (TransA == CblasNoTrans)
transb = 0;
if (TransA == CblasTrans)
transb = 1;

if (TransA == CblasConjNoTrans)
transb = 0;
if (TransA == CblasConjTrans)
transb = 1;

info = -1;

blasint ncola;
blasint ncolb;

ncola = m;
if (transa & 1) ncola = k;
ncolb = k;

if (transb & 1) {
ncolb = m;
}

if (ldc < MAX(1,m))
info = 13;
if (ldb < MAX(1, ncolb))
info = 8;
if (lda < MAX(1, ncola))
info = 10;
if (k < 0)
info = 5;
if (m < 0)
info = 4;
if (transb < 0)
info = 2;
if (transa < 0)
info = 3;
if (uplo < 0)
info = 1;
}

if (info >= 0) {
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
return;
}

#endif
int buffer_size;
blasint i, j;

#ifdef SMP
int nthreads;
#endif


#ifdef SMP
static int (*gemv_thread[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *,
BLASLONG, IFLOAT *, BLASLONG, FLOAT,
FLOAT *, BLASLONG, int) = {
sbgemv_thread_n, sbgemv_thread_t,
};
#endif
int (*gemv[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG,
IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
SBGEMV_N, SBGEMV_T,};


if (m == 0)
return;

IDEBUG_START;

const blasint incb = ((transb & 1) == 0) ? 1 : ldb;

if (uplo == 1) {
for (i = 0; i < m; i++) {
j = m - i;

aa = a + i;
bb = b + i * ldb;
if (transa & 1) {
aa = a + lda * i;
}
if (transb & 1)
bb = b + i;
cc = c + i * ldc + i;

#if 0
if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);

if (alpha == ZERO)
continue;
#endif

IDEBUG_START;

buffer_size = j + k + 128 / sizeof(FLOAT);
#ifdef WINDOWS_ABI
buffer_size += 160 / sizeof(FLOAT);
#endif
// for alignment
buffer_size = (buffer_size + 3) & ~3;
STACK_ALLOC(buffer_size, IFLOAT, buffer);

#ifdef SMP

if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
nthreads = 1;
else
nthreads = num_cpu_avail(2);

if (nthreads == 1) {
#endif

if (!(transa & 1))
(gemv[(int)transa]) (j, k, alpha, aa, lda,
bb, incb, beta, cc, 1);
else
(gemv[(int)transa]) (k, j, alpha, aa, lda,
bb, incb, beta, cc, 1);

#ifdef SMP
} else {
if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, aa,
lda, bb, incb, beta, cc,
1, nthreads);
else
(gemv_thread[(int)transa]) (k, j, alpha, aa,
lda, bb, incb, beta, cc,
1, nthreads);

}
#endif

STACK_FREE(buffer);
}
} else {

for (i = 0; i < m; i++) {
j = i + 1;

bb = b + i * ldb;
if (transb & 1) {
bb = b + i;
}
cc = c + i * ldc;

#if 0
if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);

if (alpha == ZERO)
continue;
#endif
IDEBUG_START;

buffer_size = j + k + 128 / sizeof(FLOAT);
#ifdef WINDOWS_ABI
buffer_size += 160 / sizeof(FLOAT);
#endif
// for alignment
buffer_size = (buffer_size + 3) & ~3;
STACK_ALLOC(buffer_size, IFLOAT, buffer);

#ifdef SMP

if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
nthreads = 1;
else
nthreads = num_cpu_avail(2);

if (nthreads == 1) {
#endif

if (!(transa & 1))
(gemv[(int)transa]) (j, k, alpha, a, lda, bb,
incb, beta, cc, 1);
else
(gemv[(int)transa]) (k, j, alpha, a, lda, bb,
incb, beta, cc, 1);

#ifdef SMP
} else {
if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
bb, incb, beta, cc, 1,
nthreads);
else
(gemv_thread[(int)transa]) (k, j, alpha, a, lda,
bb, incb, beta, cc, 1,
nthreads);
}
#endif

STACK_FREE(buffer);
}
}

IDEBUG_END;

return;
}

+ 14
- 4
test/compare_sgemm_sbgemm.c View File

@@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16)
return f32.v;
}

float
float32to16 (float32_bits f32)
{
bfloat16_bits f16;
f16.bits.s = f32.bits.s;
f16.bits.e = f32.bits.e;
f16.bits.m = (uint32_t) f32.bits.m >> 16;
return f32.v;
}

int
main (int argc, char *argv[])
{
@@ -108,16 +118,16 @@ main (int argc, char *argv[])
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
C[j * k + i] = 0;
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
AA[j * k + i].v = float32to16( A[j * k + i] );
BB[j * k + i].v = float32to16( B[j * k + i] );
CC[j * k + i] = 0;
DD[j * k + i] = 0;
}
}
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
&m, B, &k, &beta, C, &m);
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
&m, BB, &k, &beta, CC, &m);
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
&m, (bfloat16*)BB, &k, &beta, CC, &m);
for (i = 0; i < n; i++)
for (j = 0; j < m; j++)
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)


Loading…
Cancel
Save