|
|
|
@@ -16,202 +16,213 @@ |
|
|
|
|
|
|
|
#include "nnacl/fp16/matmul_fp16.h" |
|
|
|
|
|
|
|
// Repack a column-major fp16 matrix (row x col) into the Row8-major layout
// consumed by the fp16 matmul kernels: columns are processed in blocks of
// C8NUM, and inside each block elements are stored row-interleaved so that
// each row contributes C8NUM contiguous halfwords.
//   src_ptr: column-major float16 input (opaque pointer, cast internally)
//   dst_ptr: packed float16 output, capacity UP_ROUND(col, C8NUM) * row
//   row/col: logical matrix dimensions
static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
  int row_c8 = row / C8NUM * C8NUM;  // rows covered by the 8x8 tiled fast path
  int col_c8 = col / C8NUM * C8NUM;  // columns covered by full 8-column blocks
  const float16_t *src = (const float16_t *)src_ptr;
  int ci = 0;
  for (; ci < col_c8; ci += C8NUM) {
    int ri = 0;
    // Full 8x8 tiles: transpose one tile per iteration.
    for (; ri < row_c8; ri += C8NUM) {
      const float16_t *src_ptr1 = src + ci * row + ri;
      float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
      size_t strid_row = row * 2;  // byte stride between adjacent columns (sizeof(float16_t) == 2)
      // NEON 8x8 fp16 transpose: load 8 columns, shuffle with a zip/trn
      // network, store 8 contiguous packed rows.
      asm volatile(
        "mov x10, %[src_ptr1]\n"
        "mov x11, %[dst_ptr1]\n"
        "mov x12, %[strid_row]\n"
        "ld1 {v0.8h}, [x10], x12\n"
        "ld1 {v1.8h}, [x10], x12\n"
        "ld1 {v2.8h}, [x10], x12\n"
        "ld1 {v3.8h}, [x10], x12\n"
        "ld1 {v4.8h}, [x10], x12\n"
        "ld1 {v5.8h}, [x10], x12\n"
        "ld1 {v6.8h}, [x10], x12\n"
        "ld1 {v7.8h}, [x10], x12\n"

        "zip1 v8.8h, v0.8h, v1.8h\n"
        "zip1 v9.8h, v2.8h, v3.8h\n"
        "zip1 v10.8h, v4.8h, v5.8h\n"
        "zip1 v11.8h, v6.8h, v7.8h\n"

        "trn1 v12.4s, v8.4s, v9.4s\n"
        "trn1 v14.4s, v10.4s, v11.4s\n"
        "trn2 v13.4s, v8.4s, v9.4s\n"
        "trn2 v15.4s, v10.4s, v11.4s\n"

        "trn1 v16.2d, v12.2d, v14.2d\n"
        "trn2 v18.2d, v12.2d, v14.2d\n"
        "trn1 v17.2d, v13.2d, v15.2d\n"
        "trn2 v19.2d, v13.2d, v15.2d\n"

        "zip2 v8.8h, v0.8h, v1.8h\n"
        "zip2 v9.8h, v2.8h, v3.8h\n"
        "zip2 v10.8h, v4.8h, v5.8h\n"
        "zip2 v11.8h, v6.8h, v7.8h\n"

        "trn1 v12.4s, v8.4s, v9.4s\n"
        "trn1 v14.4s, v10.4s, v11.4s\n"
        "trn2 v13.4s, v8.4s, v9.4s\n"
        "trn2 v15.4s, v10.4s, v11.4s\n"

        "trn1 v20.2d, v12.2d, v14.2d\n"
        "trn2 v22.2d, v12.2d, v14.2d\n"
        "trn1 v21.2d, v13.2d, v15.2d\n"
        "trn2 v23.2d, v13.2d, v15.2d\n"

        "st1 {v16.8h}, [x11], #16\n"
        "st1 {v17.8h}, [x11], #16\n"
        "st1 {v18.8h}, [x11], #16\n"
        "st1 {v19.8h}, [x11], #16\n"
        "st1 {v20.8h}, [x11], #16\n"
        "st1 {v21.8h}, [x11], #16\n"
        "st1 {v22.8h}, [x11], #16\n"
        "st1 {v23.8h}, [x11], #16\n"
        :
        : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
        : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
          "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
      // Scalar 8x8 tile transpose fallback.
      for (int tr = 0; tr < C8NUM; ++tr) {
        for (int tc = 0; tc < C8NUM; ++tc) {
          dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
        }
      }
#endif
    }
    // Remaining rows (< C8NUM) of the current 8-column block, scalar path.
    for (; ri < row; ++ri) {
      const float16_t *src_ptr1 = src + ci * row;
      float16_t *dst_ptr1 = dst_ptr + ci * row;
      for (int tc = 0; tc < C8NUM; ++tc) {
        dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
      }
    }
  }
  // Trailing columns (< C8NUM): scatter element-wise into the packed layout.
  for (int r = 0; r < row; r++) {
    for (int tc = ci; tc < col; tc++) {
      int cd8 = tc / C8NUM;
      int cm8 = tc % C8NUM;
      dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
    }
  }
}
|
|
|
|
|
|
|
// Repack a column-major fp32 matrix (row x col) into the Row8-major fp16
// layout used by the fp16 matmul kernels, converting each element from
// float to float16_t along the way. Layout is identical to
// Col2Row8SrcFromFp16: 8-column blocks, row-interleaved inside a block.
//   src_ptr: column-major float input (opaque pointer, cast internally)
//   dst_ptr: packed float16 output, capacity UP_ROUND(col, C8NUM) * row
static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
  int row_c8 = row / C8NUM * C8NUM;  // rows covered by the 8x8 tiled fast path
  int col_c8 = col / C8NUM * C8NUM;  // columns covered by full 8-column blocks
  int ci = 0;
  const float *src = (const float *)src_ptr;
  for (; ci < col_c8; ci += C8NUM) {
    int ri = 0;
    // Full 8x8 tiles: convert fp32->fp16 and transpose one tile per iteration.
    for (; ri < row_c8; ri += C8NUM) {
      const float *src_ptr1 = src + ci * row + ri;
      float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
      size_t strid_row = row * 4;  // byte stride between adjacent columns (sizeof(float) == 4)
      // NEON path: load 8 fp32 columns (two q-registers each), narrow to fp16
      // with fcvtn/fcvtn2, then run the same zip/trn transpose network as the
      // fp16 source variant and store 8 contiguous packed rows.
      asm volatile(
        "mov x10, %[src_ptr1]\n"
        "mov x11, %[dst_ptr1]\n"
        "mov x12, %[strid_row]\n"
        "ld1 {v8.4s, v9.4s}, [x10], x12\n"
        "ld1 {v10.4s, v11.4s}, [x10], x12\n"
        "ld1 {v12.4s, v13.4s}, [x10], x12\n"
        "ld1 {v14.4s, v15.4s}, [x10], x12\n"
        "ld1 {v16.4s, v17.4s}, [x10], x12\n"
        "ld1 {v18.4s, v19.4s}, [x10], x12\n"
        "ld1 {v20.4s, v21.4s}, [x10], x12\n"
        "ld1 {v22.4s, v23.4s}, [x10], x12\n"

        "fcvtn v0.4h, v8.4s\n"
        "fcvtn2 v0.8h, v9.4s\n"
        "fcvtn v1.4h, v10.4s\n"
        "fcvtn2 v1.8h, v11.4s\n"
        "fcvtn v2.4h, v12.4s\n"
        "fcvtn2 v2.8h, v13.4s\n"
        "fcvtn v3.4h, v14.4s\n"
        "fcvtn2 v3.8h, v15.4s\n"
        "fcvtn v4.4h, v16.4s\n"
        "fcvtn2 v4.8h, v17.4s\n"
        "fcvtn v5.4h, v18.4s\n"
        "fcvtn2 v5.8h, v19.4s\n"
        "fcvtn v6.4h, v20.4s\n"
        "fcvtn2 v6.8h, v21.4s\n"
        "fcvtn v7.4h, v22.4s\n"
        "fcvtn2 v7.8h, v23.4s\n"

        "zip1 v8.8h, v0.8h, v1.8h\n"
        "zip1 v9.8h, v2.8h, v3.8h\n"
        "zip1 v10.8h, v4.8h, v5.8h\n"
        "zip1 v11.8h, v6.8h, v7.8h\n"

        "trn1 v12.4s, v8.4s, v9.4s\n"
        "trn1 v14.4s, v10.4s, v11.4s\n"
        "trn2 v13.4s, v8.4s, v9.4s\n"
        "trn2 v15.4s, v10.4s, v11.4s\n"

        "trn1 v16.2d, v12.2d, v14.2d\n"
        "trn2 v18.2d, v12.2d, v14.2d\n"
        "trn1 v17.2d, v13.2d, v15.2d\n"
        "trn2 v19.2d, v13.2d, v15.2d\n"

        "zip2 v8.8h, v0.8h, v1.8h\n"
        "zip2 v9.8h, v2.8h, v3.8h\n"
        "zip2 v10.8h, v4.8h, v5.8h\n"
        "zip2 v11.8h, v6.8h, v7.8h\n"

        "trn1 v12.4s, v8.4s, v9.4s\n"
        "trn1 v14.4s, v10.4s, v11.4s\n"
        "trn2 v13.4s, v8.4s, v9.4s\n"
        "trn2 v15.4s, v10.4s, v11.4s\n"

        "trn1 v20.2d, v12.2d, v14.2d\n"
        "trn2 v22.2d, v12.2d, v14.2d\n"
        "trn1 v21.2d, v13.2d, v15.2d\n"
        "trn2 v23.2d, v13.2d, v15.2d\n"

        "st1 {v16.8h}, [x11], #16\n"
        "st1 {v17.8h}, [x11], #16\n"
        "st1 {v18.8h}, [x11], #16\n"
        "st1 {v19.8h}, [x11], #16\n"
        "st1 {v20.8h}, [x11], #16\n"
        "st1 {v21.8h}, [x11], #16\n"
        "st1 {v22.8h}, [x11], #16\n"
        "st1 {v23.8h}, [x11], #16\n"
        :
        : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
        : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
          "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
      // Scalar 8x8 tile: transpose with per-element fp32 -> fp16 conversion.
      for (int tr = 0; tr < C8NUM; ++tr) {
        for (int tc = 0; tc < C8NUM; ++tc) {
          dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
        }
      }
#endif
    }
    // Remaining rows (< C8NUM) of the current 8-column block, scalar path.
    for (; ri < row; ++ri) {
      const float *src_ptr1 = src + ci * row;
      float16_t *dst_ptr1 = dst_ptr + ci * row;
      for (int tc = 0; tc < C8NUM; ++tc) {
        dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
      }
    }
  }
  // Trailing columns (< C8NUM): scatter element-wise into the packed layout.
  for (int r = 0; r < row; r++) {
    for (int tc = ci; tc < col; tc++) {
      int cd8 = tc / C8NUM;
      int cm8 = tc % C8NUM;
      dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
    }
  }
}
|
|
|
|
|
|
|
// Dispatch column-major -> Row8-major fp16 repacking based on the source
// element type: fp16 input goes straight through, fp32 input is narrowed
// to fp16 during the repack.
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
  if (!src_float16) {
    Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col);
  } else {
    Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col);
  }
}
|
|
|
|
|
|
|
@@ -274,126 +285,129 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f |
|
|
|
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); |
|
|
|
} |
|
|
|
|
|
|
|
// Transpose one 16x8 fp16 tile from a row-major source into a column-major
// 16-row block: loads 16 rows of 8 halfwords (row stride = col elements) and
// stores 8 contiguous groups of 16 halfwords. Pure ARM64 NEON; callers guard
// with ENABLE_ARM64.
//   src_ptr: top-left of the 16x8 tile in the row-major source
//   dst_ptr: destination block (8 * C16NUM halfwords, written sequentially)
//   col:     source row length in elements (used to derive the byte stride)
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
  size_t stride = col * 2;  // byte stride between consecutive source rows (fp16 = 2 bytes)
  asm volatile(
    "mov x10, %[src_c]\n"
    "mov x11, %[dst_c]\n"

    // Rows 0-7.
    "ld1 {v0.8h}, [x10], %[stride]\n"
    "ld1 {v1.8h}, [x10], %[stride]\n"
    "ld1 {v2.8h}, [x10], %[stride]\n"
    "ld1 {v3.8h}, [x10], %[stride]\n"
    "ld1 {v4.8h}, [x10], %[stride]\n"
    "ld1 {v5.8h}, [x10], %[stride]\n"
    "ld1 {v6.8h}, [x10], %[stride]\n"
    "ld1 {v7.8h}, [x10], %[stride]\n"

    "zip1 v16.8h, v0.8h, v1.8h\n"
    "zip1 v17.8h, v2.8h, v3.8h\n"
    "zip1 v18.8h, v4.8h, v5.8h\n"
    "zip1 v19.8h, v6.8h, v7.8h\n"

    // Rows 8-15.
    "ld1 {v8.8h}, [x10], %[stride]\n"
    "ld1 {v9.8h}, [x10], %[stride]\n"
    "ld1 {v10.8h}, [x10], %[stride]\n"
    "ld1 {v11.8h}, [x10], %[stride]\n"
    "ld1 {v12.8h}, [x10], %[stride]\n"
    "ld1 {v13.8h}, [x10], %[stride]\n"
    "ld1 {v14.8h}, [x10], %[stride]\n"
    "ld1 {v15.8h}, [x10], %[stride]\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"

    "zip1 v16.8h, v8.8h, v9.8h\n"
    "zip1 v17.8h, v10.8h, v11.8h\n"
    "zip1 v18.8h, v12.8h, v13.8h\n"
    "zip1 v19.8h, v14.8h, v15.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"

    // Interleave stores: 16 halfwords per output column (8 from each row half).
    "st1 {v24.8h}, [x11], #16\n"
    "st1 {v28.8h}, [x11], #16\n"
    "st1 {v26.8h}, [x11], #16\n"
    "st1 {v30.8h}, [x11], #16\n"
    "st1 {v25.8h}, [x11], #16\n"
    "st1 {v29.8h}, [x11], #16\n"
    "st1 {v27.8h}, [x11], #16\n"
    "st1 {v31.8h}, [x11], #16\n"

    // Second half of each row (zip2 lanes), same shuffle network.
    "zip2 v16.8h, v0.8h, v1.8h\n"
    "zip2 v17.8h, v2.8h, v3.8h\n"
    "zip2 v18.8h, v4.8h, v5.8h\n"
    "zip2 v19.8h, v6.8h, v7.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"

    "zip2 v16.8h, v8.8h, v9.8h\n"
    "zip2 v17.8h, v10.8h, v11.8h\n"
    "zip2 v18.8h, v12.8h, v13.8h\n"
    "zip2 v19.8h, v14.8h, v15.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"

    "st1 {v24.8h}, [x11], #16\n"
    "st1 {v28.8h}, [x11], #16\n"
    "st1 {v26.8h}, [x11], #16\n"
    "st1 {v30.8h}, [x11], #16\n"
    "st1 {v25.8h}, [x11], #16\n"
    "st1 {v29.8h}, [x11], #16\n"
    "st1 {v27.8h}, [x11], #16\n"
    "st1 {v31.8h}, [x11], #16\n"
    :
    : [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride)
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31");
}
|
|
|
|
|
|
|
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { |
|
|
|
size_t row_up_16 = UP_ROUND(row, C16NUM); |
|
|
|
size_t row16 = row / C16NUM * C16NUM; |
|
|
|
size_t col8 = col / C8NUM * C8NUM; |
|
|
|
const float16_t *src_r = src_ptr; |
|
|
|
float16_t *dst_r = dst_ptr; |
|
|
|
|
|
|
|
size_t ri = 0; |
|
|
|
// find 16 block unit |
|
|
|
for (; ri < row16; ri += C16NUM) { |
|
|
|
size_t ci = 0; |
|
|
|
for (; ci < col8; ci += C8NUM) { |
|
|
|
const float16_t *src_c = src_r + ci; |
|
|
|
float16_t *dst_c = dst_r + ci * C16NUM; |
|
|
|
|
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
size_t stride = col * 2; |
|
|
|
asm volatile( |
|
|
|
"mov x10, %[src_c]\n" |
|
|
|
"mov x11, %[dst_c]\n" |
|
|
|
|
|
|
|
"ld1 {v0.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v1.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v2.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v3.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v4.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v5.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v6.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v7.8h}, [x10], %[stride]\n" |
|
|
|
|
|
|
|
"zip1 v16.8h, v0.8h, v1.8h\n" |
|
|
|
"zip1 v17.8h, v2.8h, v3.8h\n" |
|
|
|
"zip1 v18.8h, v4.8h, v5.8h\n" |
|
|
|
"zip1 v19.8h, v6.8h, v7.8h\n" |
|
|
|
|
|
|
|
"ld1 {v8.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v9.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v10.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v11.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v12.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v13.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v14.8h}, [x10], %[stride]\n" |
|
|
|
"ld1 {v15.8h}, [x10], %[stride]\n" |
|
|
|
|
|
|
|
"trn1 v20.4s, v16.4s, v17.4s\n" |
|
|
|
"trn2 v21.4s, v16.4s, v17.4s\n" |
|
|
|
"trn1 v22.4s, v18.4s, v19.4s\n" |
|
|
|
"trn2 v23.4s, v18.4s, v19.4s\n" |
|
|
|
|
|
|
|
"trn1 v24.2d, v20.2d, v22.2d\n" |
|
|
|
"trn2 v25.2d, v20.2d, v22.2d\n" |
|
|
|
"trn1 v26.2d, v21.2d, v23.2d\n" |
|
|
|
"trn2 v27.2d, v21.2d, v23.2d\n" |
|
|
|
|
|
|
|
"zip1 v16.8h, v8.8h, v9.8h\n" |
|
|
|
"zip1 v17.8h, v10.8h, v11.8h\n" |
|
|
|
"zip1 v18.8h, v12.8h, v13.8h\n" |
|
|
|
"zip1 v19.8h, v14.8h, v15.8h\n" |
|
|
|
|
|
|
|
"trn1 v20.4s, v16.4s, v17.4s\n" |
|
|
|
"trn2 v21.4s, v16.4s, v17.4s\n" |
|
|
|
"trn1 v22.4s, v18.4s, v19.4s\n" |
|
|
|
"trn2 v23.4s, v18.4s, v19.4s\n" |
|
|
|
|
|
|
|
"trn1 v28.2d, v20.2d, v22.2d\n" |
|
|
|
"trn2 v29.2d, v20.2d, v22.2d\n" |
|
|
|
"trn1 v30.2d, v21.2d, v23.2d\n" |
|
|
|
"trn2 v31.2d, v21.2d, v23.2d\n" |
|
|
|
|
|
|
|
"st1 {v24.8h}, [x11], #16\n" |
|
|
|
"st1 {v28.8h}, [x11], #16\n" |
|
|
|
"st1 {v26.8h}, [x11], #16\n" |
|
|
|
"st1 {v30.8h}, [x11], #16\n" |
|
|
|
"st1 {v25.8h}, [x11], #16\n" |
|
|
|
"st1 {v29.8h}, [x11], #16\n" |
|
|
|
"st1 {v27.8h}, [x11], #16\n" |
|
|
|
"st1 {v31.8h}, [x11], #16\n" |
|
|
|
|
|
|
|
"zip2 v16.8h, v0.8h, v1.8h\n" |
|
|
|
"zip2 v17.8h, v2.8h, v3.8h\n" |
|
|
|
"zip2 v18.8h, v4.8h, v5.8h\n" |
|
|
|
"zip2 v19.8h, v6.8h, v7.8h\n" |
|
|
|
|
|
|
|
"trn1 v20.4s, v16.4s, v17.4s\n" |
|
|
|
"trn2 v21.4s, v16.4s, v17.4s\n" |
|
|
|
"trn1 v22.4s, v18.4s, v19.4s\n" |
|
|
|
"trn2 v23.4s, v18.4s, v19.4s\n" |
|
|
|
|
|
|
|
"trn1 v24.2d, v20.2d, v22.2d\n" |
|
|
|
"trn2 v25.2d, v20.2d, v22.2d\n" |
|
|
|
"trn1 v26.2d, v21.2d, v23.2d\n" |
|
|
|
"trn2 v27.2d, v21.2d, v23.2d\n" |
|
|
|
|
|
|
|
"zip2 v16.8h, v8.8h, v9.8h\n" |
|
|
|
"zip2 v17.8h, v10.8h, v11.8h\n" |
|
|
|
"zip2 v18.8h, v12.8h, v13.8h\n" |
|
|
|
"zip2 v19.8h, v14.8h, v15.8h\n" |
|
|
|
|
|
|
|
"trn1 v20.4s, v16.4s, v17.4s\n" |
|
|
|
"trn2 v21.4s, v16.4s, v17.4s\n" |
|
|
|
"trn1 v22.4s, v18.4s, v19.4s\n" |
|
|
|
"trn2 v23.4s, v18.4s, v19.4s\n" |
|
|
|
|
|
|
|
"trn1 v28.2d, v20.2d, v22.2d\n" |
|
|
|
"trn2 v29.2d, v20.2d, v22.2d\n" |
|
|
|
"trn1 v30.2d, v21.2d, v23.2d\n" |
|
|
|
"trn2 v31.2d, v21.2d, v23.2d\n" |
|
|
|
|
|
|
|
"st1 {v24.8h}, [x11], #16\n" |
|
|
|
"st1 {v28.8h}, [x11], #16\n" |
|
|
|
"st1 {v26.8h}, [x11], #16\n" |
|
|
|
"st1 {v30.8h}, [x11], #16\n" |
|
|
|
"st1 {v25.8h}, [x11], #16\n" |
|
|
|
"st1 {v29.8h}, [x11], #16\n" |
|
|
|
"st1 {v27.8h}, [x11], #16\n" |
|
|
|
"st1 {v31.8h}, [x11], #16\n" |
|
|
|
: |
|
|
|
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) |
|
|
|
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", |
|
|
|
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", |
|
|
|
"v30", "v31"); |
|
|
|
Row2Col16Block16(src_c, dst_c, col); |
|
|
|
#else |
|
|
|
for (int tr = 0; tr < C16NUM; tr++) { |
|
|
|
for (int tc = 0; tc < C8NUM; tc++) { |
|
|
|
@@ -413,7 +427,7 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si |
|
|
|
dst_r += C16NUM * col; |
|
|
|
} |
|
|
|
for (; ri < row; ri++) { |
|
|
|
for (size_t i = 0; i < col; i++) { |
|
|
|
for (size_t i = 0; i < col; ++i) { |
|
|
|
dst_r[i * C16NUM] = src_r[i]; |
|
|
|
} |
|
|
|
src_r += col; |
|
|
|
|