|
|
@@ -27,72 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
#include <stdio.h> |
|
|
#include <stdio.h> |
|
|
#include <stdint.h> |
|
|
#include <stdint.h> |
|
|
#include "../common.h" |
|
|
#include "../common.h" |
|
|
|
|
|
|
|
|
|
|
|
#include "test_helpers.h" |
|
|
|
|
|
|
|
|
#define SGEMM BLASFUNC(sgemm) |
|
|
#define SGEMM BLASFUNC(sgemm) |
|
|
#define SBGEMM BLASFUNC(sbgemm) |
|
|
#define SBGEMM BLASFUNC(sbgemm) |
|
|
#define SGEMV BLASFUNC(sgemv) |
|
|
#define SGEMV BLASFUNC(sgemv) |
|
|
#define SBGEMV BLASFUNC(sbgemv) |
|
|
#define SBGEMV BLASFUNC(sbgemv) |
|
|
typedef union |
|
|
|
|
|
{ |
|
|
|
|
|
unsigned short v; |
|
|
|
|
|
#if defined(_AIX) |
|
|
|
|
|
struct __attribute__((packed)) |
|
|
|
|
|
#else |
|
|
|
|
|
struct |
|
|
|
|
|
#endif |
|
|
|
|
|
{ |
|
|
|
|
|
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
|
|
|
|
|
unsigned short s:1; |
|
|
|
|
|
unsigned short e:8; |
|
|
|
|
|
unsigned short m:7; |
|
|
|
|
|
#else |
|
|
|
|
|
unsigned short m:7; |
|
|
|
|
|
unsigned short e:8; |
|
|
|
|
|
unsigned short s:1; |
|
|
|
|
|
#endif |
|
|
|
|
|
} bits; |
|
|
|
|
|
} bfloat16_bits; |
|
|
|
|
|
|
|
|
|
|
|
typedef union |
|
|
|
|
|
{ |
|
|
|
|
|
float v; |
|
|
|
|
|
#if defined(_AIX) |
|
|
|
|
|
struct __attribute__((packed)) |
|
|
|
|
|
#else |
|
|
|
|
|
struct |
|
|
|
|
|
#endif |
|
|
|
|
|
{ |
|
|
|
|
|
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
|
|
|
|
|
uint32_t s:1; |
|
|
|
|
|
uint32_t e:8; |
|
|
|
|
|
uint32_t m:23; |
|
|
|
|
|
#else |
|
|
|
|
|
uint32_t m:23; |
|
|
|
|
|
uint32_t e:8; |
|
|
|
|
|
uint32_t s:1; |
|
|
|
|
|
#endif |
|
|
|
|
|
} bits; |
|
|
|
|
|
} float32_bits; |
|
|
|
|
|
|
|
|
|
|
|
float |
|
|
|
|
|
float16to32 (bfloat16_bits f16) |
|
|
|
|
|
{ |
|
|
|
|
|
float32_bits f32; |
|
|
|
|
|
f32.bits.s = f16.bits.s; |
|
|
|
|
|
f32.bits.e = f16.bits.e; |
|
|
|
|
|
f32.bits.m = (uint32_t) f16.bits.m << 16; |
|
|
|
|
|
return f32.v; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#define SBGEMM_LARGEST 256 |
|
|
#define SBGEMM_LARGEST 256 |
|
|
|
|
|
|
|
|
void *malloc_safe(size_t size) |
|
|
|
|
|
{ |
|
|
|
|
|
if (size == 0) |
|
|
|
|
|
return malloc(1); |
|
|
|
|
|
else |
|
|
|
|
|
return malloc(size); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int |
|
|
int |
|
|
main (int argc, char *argv[]) |
|
|
main (int argc, char *argv[]) |
|
|
{ |
|
|
{ |
|
|
@@ -111,14 +54,13 @@ main (int argc, char *argv[]) |
|
|
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); |
|
|
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); |
|
|
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); |
|
|
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); |
|
|
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); |
|
|
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); |
|
|
|
|
|
|
|
|
bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); |
|
|
|
|
|
bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); |
|
|
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
(DD == NULL) || (CC == NULL)) |
|
|
(DD == NULL) || (CC == NULL)) |
|
|
return 1; |
|
|
return 1; |
|
|
bfloat16 atmp,btmp; |
|
|
|
|
|
blasint one=1; |
|
|
blasint one=1; |
|
|
|
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
for (j = 0; j < m; j++) |
|
|
@@ -126,8 +68,7 @@ main (int argc, char *argv[]) |
|
|
for (i = 0; i < k; i++) |
|
|
for (i = 0; i < k; i++) |
|
|
{ |
|
|
{ |
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); |
|
|
|
|
|
AA[j * k + i].v = atmp; |
|
|
|
|
|
|
|
|
sbstobf16_(&one, &A[j*k+i], &one, &AA[j * k + i], &one); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
for (j = 0; j < n; j++) |
|
|
for (j = 0; j < n; j++) |
|
|
@@ -135,8 +76,7 @@ main (int argc, char *argv[]) |
|
|
for (i = 0; i < k; i++) |
|
|
for (i = 0; i < k; i++) |
|
|
{ |
|
|
{ |
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); |
|
|
|
|
|
BB[j * k + i].v = btmp; |
|
|
|
|
|
|
|
|
sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
for (y = 0; y < 4; y++) |
|
|
for (y = 0; y < 4; y++) |
|
|
@@ -182,10 +122,12 @@ main (int argc, char *argv[]) |
|
|
DD[i * m + j] += |
|
|
DD[i * m + j] += |
|
|
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); |
|
|
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); |
|
|
} |
|
|
} |
|
|
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) |
|
|
|
|
|
|
|
|
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { |
|
|
ret++; |
|
|
ret++; |
|
|
if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0) |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { |
|
|
ret++; |
|
|
ret++; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
free(A); |
|
|
free(A); |
|
|
@@ -211,14 +153,13 @@ main (int argc, char *argv[]) |
|
|
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); |
|
|
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); |
|
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); |
|
|
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); |
|
|
|
|
|
|
|
|
bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); |
|
|
|
|
|
bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); |
|
|
float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); |
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
(DD == NULL) || (CC == NULL)) |
|
|
(DD == NULL) || (CC == NULL)) |
|
|
return 1; |
|
|
return 1; |
|
|
bfloat16 atmp, btmp; |
|
|
|
|
|
blasint one = 1; |
|
|
blasint one = 1; |
|
|
|
|
|
|
|
|
for (j = 0; j < x; j++) |
|
|
for (j = 0; j < x; j++) |
|
|
@@ -226,12 +167,10 @@ main (int argc, char *argv[]) |
|
|
for (i = 0; i < x; i++) |
|
|
for (i = 0; i < x; i++) |
|
|
{ |
|
|
{ |
|
|
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); |
|
|
|
|
|
AA[j * x + i].v = atmp; |
|
|
|
|
|
|
|
|
sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one); |
|
|
} |
|
|
} |
|
|
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
sbstobf16_(&one, &B[j << l], &one, &btmp, &one); |
|
|
|
|
|
BB[j << l].v = btmp; |
|
|
|
|
|
|
|
|
sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one); |
|
|
|
|
|
|
|
|
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
} |
|
|
} |
|
|
@@ -262,10 +201,12 @@ main (int argc, char *argv[]) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (j = 0; j < x; j++) { |
|
|
for (j = 0; j < x; j++) { |
|
|
if (fabs (CC[j << l] - C[j << l]) > 1.0) |
|
|
|
|
|
|
|
|
if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { |
|
|
ret++; |
|
|
ret++; |
|
|
if (fabs (CC[j << l] - DD[j]) > 1.0) |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { |
|
|
ret++; |
|
|
ret++; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
free(A); |
|
|
free(A); |
|
|
|