|
|
|
@@ -205,15 +205,14 @@ main (int argc, char *argv[]) |
|
|
|
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. |
|
|
|
for (x = 1; x <= loop; x++) |
|
|
|
{ |
|
|
|
m = l + 1; |
|
|
|
k = (x == 0) ? 0 : m; |
|
|
|
k = (x == 0) ? 0 : l + 1; |
|
|
|
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); |
|
|
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m); |
|
|
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m); |
|
|
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); |
|
|
|
float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m); |
|
|
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
|
(DD == NULL) || (CC == NULL)) |
|
|
|
return 1; |
|
|
|
@@ -228,9 +227,9 @@ main (int argc, char *argv[]) |
|
|
|
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); |
|
|
|
AA[j * x + i].v = atmp; |
|
|
|
} |
|
|
|
B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
sbstobf16_(&one, &B[j*m], &one, &btmp, &one); |
|
|
|
BB[j*m].v = btmp; |
|
|
|
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
sbstobf16_(&one, &B[j << l], &one, &btmp, &one); |
|
|
|
BB[j << l].v = btmp; |
|
|
|
} |
|
|
|
for (y = 0; y < 2; y++) |
|
|
|
{ |
|
|
|
@@ -240,9 +239,9 @@ main (int argc, char *argv[]) |
|
|
|
transA = 'T'; |
|
|
|
} |
|
|
|
|
|
|
|
memset(CC, 0, x * m * sizeof(FLOAT)); |
|
|
|
memset(CC, 0, x * sizeof(FLOAT) << l); |
|
|
|
memset(DD, 0, x * sizeof(FLOAT)); |
|
|
|
memset(C, 0, x * m * sizeof(FLOAT)); |
|
|
|
memset(C, 0, x * sizeof(FLOAT) << l); |
|
|
|
|
|
|
|
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); |
|
|
|
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); |
|
|
|
@@ -250,15 +249,15 @@ main (int argc, char *argv[]) |
|
|
|
for (j = 0; j < x; j++) |
|
|
|
for (i = 0; i < x; i++) |
|
|
|
if (transA == 'N') { |
|
|
|
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]); |
|
|
|
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); |
|
|
|
} else if (transA == 'T') { |
|
|
|
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]); |
|
|
|
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); |
|
|
|
} |
|
|
|
|
|
|
|
for (j = 0; j < x; j++) { |
|
|
|
if (fabs (CC[j*m] - C[j*m]) > 1.0) |
|
|
|
if (fabs (CC[j << l] - C[j << l]) > 1.0) |
|
|
|
ret++; |
|
|
|
if (fabs (CC[j*m] - DD[j]) > 1.0) |
|
|
|
if (fabs (CC[j << l] - DD[j]) > 1.0) |
|
|
|
ret++; |
|
|
|
} |
|
|
|
} |
|
|
|
|