| @@ -86,14 +86,26 @@ main (int argc, char *argv[]) | |||
| { | |||
| blasint m, n, k; | |||
| int i, j, l; | |||
| blasint x; | |||
| blasint x, y; | |||
| int ret = 0; | |||
| int loop = 100; | |||
| char transA = 'N', transB = 'N'; | |||
| float alpha = 1.0, beta = 0.0; | |||
| for (x = 0; x <= loop; x++) | |||
| { | |||
| for (y = 0; y < 4; y++) | |||
| { | |||
| if ((y == 0) || (y == 2)) { | |||
| transA = 'N'; | |||
| } else { | |||
| transA = 'T'; | |||
| } | |||
| if ((y == 0) || (y == 1)) { | |||
| transB = 'N'; | |||
| } else { | |||
| transB = 'T'; | |||
| } | |||
| m = k = n = x; | |||
| float A[m * k]; | |||
| float B[k * n]; | |||
| @@ -104,43 +116,55 @@ main (int argc, char *argv[]) | |||
| blasint one=1; | |||
| for (j = 0; j < m; j++) | |||
| { | |||
| for (i = 0; i < m; i++) | |||
| { | |||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| C[j * k + i] = 0; | |||
| sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); | |||
| sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); | |||
| AA[j * k + i].v = atmp; | |||
| BB[j * k + i].v = btmp; | |||
| CC[j * k + i] = 0; | |||
| DD[j * k + i] = 0; | |||
| } | |||
| } | |||
| { | |||
| for (i = 0; i < m; i++) | |||
| { | |||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| C[j * k + i] = 0; | |||
| sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); | |||
| sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); | |||
| AA[j * k + i].v = atmp; | |||
| BB[j * k + i].v = btmp; | |||
| CC[j * k + i] = 0; | |||
| DD[j * k + i] = 0; | |||
| } | |||
| } | |||
| SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | |||
| &m, B, &k, &beta, C, &m); | |||
| &m, B, &k, &beta, C, &m); | |||
| SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA, | |||
| &m, (bfloat16*)BB, &k, &beta, CC, &m); | |||
| &m, (bfloat16*)BB, &k, &beta, CC, &m); | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | |||
| ret++; | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | |||
| ret++; | |||
| if (transA == 'N' && transB == 'N') | |||
| { | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); | |||
| } | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| if (CC[i * m + j] != DD[i * m + j]) | |||
| ret++; | |||
| } | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| if (transA == 'N' && transB == 'N') | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); | |||
| } else if (transA == 'T' && transB == 'N') | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[k * j + l]) * float16to32 (BB[l + k * i]); | |||
| } else if (transA == 'N' && transB == 'T') | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[l * m + j]) * float16to32 (BB[i + l * n]); | |||
| } else if (transA == 'T' && transB == 'T') | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); | |||
| } | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| if (CC[i * m + j] != DD[i * m + j]) | |||
| ret++; | |||
| } | |||
| } | |||
| if (ret != 0) | |||
| fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); | |||
| return ret; | |||