| @@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16) | |||
| #define SBGEMM_LARGEST 256 | |||
| void *malloc_safe(size_t size) | |||
| { | |||
| if (size == 0) | |||
| return malloc(1); | |||
| else | |||
| return malloc(size); | |||
| } | |||
| int | |||
| main (int argc, char *argv[]) | |||
| { | |||
| @@ -96,17 +104,17 @@ main (int argc, char *argv[]) | |||
| char transA = 'N', transB = 'N'; | |||
| float alpha = 1.0, beta = 0.0; | |||
| for (x = 1; x <= loop; x++) | |||
| for (x = 0; x <= loop; x++) | |||
| { | |||
| if ((x > 100) && (x != SBGEMM_LARGEST)) continue; | |||
| m = k = n = x; | |||
| float *A = (float *)malloc(m * k * sizeof(FLOAT)); | |||
| float *B = (float *)malloc(k * n * sizeof(FLOAT)); | |||
| float *C = (float *)malloc(m * n * sizeof(FLOAT)); | |||
| bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits)); | |||
| bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits)); | |||
| float *DD = (float *)malloc(m * n * sizeof(FLOAT)); | |||
| float *CC = (float *)malloc(m * n * sizeof(FLOAT)); | |||
| float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); | |||
| float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); | |||
| float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); | |||
| bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); | |||
| bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); | |||
| float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); | |||
| float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); | |||
| if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || | |||
| (DD == NULL) || (CC == NULL)) | |||
| return 1; | |||
| @@ -195,15 +203,15 @@ main (int argc, char *argv[]) | |||
| } | |||
| k = 1; | |||
| for (x = 1; x <= loop; x++) | |||
| for (x = 0; x <= loop; x++) | |||
| { | |||
| float *A = (float *)malloc(x * x * sizeof(FLOAT)); | |||
| float *B = (float *)malloc(x * sizeof(FLOAT)); | |||
| float *C = (float *)malloc(x * sizeof(FLOAT)); | |||
| bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits)); | |||
| bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits)); | |||
| float *DD = (float *)malloc(x * sizeof(FLOAT)); | |||
| float *CC = (float *)malloc(x * sizeof(FLOAT)); | |||
| float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); | |||
| float *B = (float *)malloc_safe(x * sizeof(FLOAT)); | |||
| float *C = (float *)malloc_safe(x * sizeof(FLOAT)); | |||
| bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); | |||
| bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); | |||
| float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); | |||
| float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); | |||
| if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || | |||
| (DD == NULL) || (CC == NULL)) | |||
| return 1; | |||