| @@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| #include "../common.h" | |||
| #define SGEMM BLASFUNC(sgemm) | |||
| #define SBGEMM BLASFUNC(sbgemm) | |||
| #define SGEMV BLASFUNC(sgemv) | |||
| #define SBGEMV BLASFUNC(sbgemv) | |||
| typedef union | |||
| { | |||
| unsigned short v; | |||
| @@ -187,7 +189,79 @@ main (int argc, char *argv[]) | |||
| free(CC); | |||
| } | |||
| if (ret != 0) | |||
| if (ret != 0) { | |||
| fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); | |||
| return ret; | |||
| } | |||
| k = 1; | |||
| for (x = 1; x <= loop; x++) | |||
| { | |||
| float *A = (float *)malloc(x * x * sizeof(FLOAT)); | |||
| float *B = (float *)malloc(x * sizeof(FLOAT)); | |||
| float *C = (float *)malloc(x * sizeof(FLOAT)); | |||
| bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits)); | |||
| bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits)); | |||
| float *DD = (float *)malloc(x * sizeof(FLOAT)); | |||
| float *CC = (float *)malloc(x * sizeof(FLOAT)); | |||
| if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || | |||
| (DD == NULL) || (CC == NULL)) | |||
| return 1; | |||
| bfloat16 atmp, btmp; | |||
| blasint one = 1; | |||
| for (j = 0; j < x; j++) | |||
| { | |||
| for (i = 0; i < x; i++) | |||
| { | |||
| A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); | |||
| AA[j * x + i].v = atmp; | |||
| } | |||
| B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| sbstobf16_(&one, &B[j], &one, &btmp, &one); | |||
| BB[j].v = btmp; | |||
| } | |||
| for (y = 0; y < 2; y++) | |||
| { | |||
| if (y == 0) { | |||
| transA = 'N'; | |||
| } else { | |||
| transA = 'T'; | |||
| } | |||
| memset(CC, 0, x * sizeof(FLOAT)); | |||
| memset(DD, 0, x * sizeof(FLOAT)); | |||
| memset(C, 0, x * sizeof(FLOAT)); | |||
| SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); | |||
| SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); | |||
| for (j = 0; j < x; j++) | |||
| for (i = 0; i < x; i++) | |||
| if (transA == 'N') { | |||
| DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]); | |||
| } else if (transA == 'T') { | |||
| DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]); | |||
| } | |||
| for (j = 0; j < x; j++) { | |||
| if (fabs (CC[j] - C[j]) > 1.0) | |||
| ret++; | |||
| if (fabs (CC[j] - DD[j]) > 1.0) | |||
| ret++; | |||
| } | |||
| } | |||
| free(A); | |||
| free(B); | |||
| free(C); | |||
| free(AA); | |||
| free(BB); | |||
| free(DD); | |||
| free(CC); | |||
| } | |||
| if (ret != 0) | |||
| fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); | |||
| return ret; | |||
| } | |||