| @@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16) | |||||
| return f32.v; | return f32.v; | ||||
| } | } | ||||
| float | |||||
| float32to16 (float32_bits f32) | |||||
| { | |||||
| bfloat16_bits f16; | |||||
| f16.bits.s = f32.bits.s; | |||||
| f16.bits.e = f32.bits.e; | |||||
| f16.bits.m = (uint32_t) f32.bits.m >> 16; | |||||
| return f32.v; | |||||
| } | |||||
| int | int | ||||
| main (int argc, char *argv[]) | main (int argc, char *argv[]) | ||||
| { | { | ||||
| @@ -108,16 +118,16 @@ main (int argc, char *argv[]) | |||||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | ||||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | ||||
| C[j * k + i] = 0; | C[j * k + i] = 0; | ||||
| AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; | |||||
| BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; | |||||
| AA[j * k + i].v = float32to16( A[j * k + i] ); | |||||
| BB[j * k + i].v = float32to16( B[j * k + i] ); | |||||
| CC[j * k + i] = 0; | CC[j * k + i] = 0; | ||||
| DD[j * k + i] = 0; | DD[j * k + i] = 0; | ||||
| } | } | ||||
| } | } | ||||
| SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | ||||
| &m, B, &k, &beta, C, &m); | &m, B, &k, &beta, C, &m); | ||||
| SBGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, | |||||
| &m, BB, &k, &beta, CC, &m); | |||||
| SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA, | |||||
| &m, (bfloat16*)BB, &k, &beta, CC, &m); | |||||
| for (i = 0; i < n; i++) | for (i = 0; i < n; i++) | ||||
| for (j = 0; j < m; j++) | for (j = 0; j < m; j++) | ||||
| if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | ||||