|
|
|
@@ -15,27 +15,57 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "nnacl/fp16/matmul_fp16.h" |
|
|
|
void ColMajor2Row8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { |
|
|
|
for (int r = 0; r < row; r++) { |
|
|
|
for (int c = 0; c < col; c++) { |
|
|
|
int cd8 = c / 8; |
|
|
|
int cm8 = c % 8; |
|
|
|
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src_ptr[c * row + r]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, |
|
|
|
int deep, int row, int col, int stride, bool write_nhwc) { |
|
|
|
int row_16 = UP_ROUND(row, C16NUM); |
|
|
|
int col_8 = UP_ROUND(col, C8NUM); |
|
|
|
/* col16-major * row8-major => row16x8-major */ |
|
|
|
if (write_nhwc) return; |
|
|
|
for (int r = 0; r < row_16; r++) { |
|
|
|
for (int c = 0; c < col_8; c++) { |
|
|
|
int r16div = r / C16NUM, r16mod = r % C16NUM; |
|
|
|
int c8div = c / C8NUM, c8mod = c % C8NUM; |
|
|
|
size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; |
|
|
|
float16_t value = 0; |
|
|
|
for (int d = 0; d < deep; d++) { |
|
|
|
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; |
|
|
|
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; |
|
|
|
value = value + a[ai] * b[bi]; |
|
|
|
if (write_nhwc) { |
|
|
|
/* col16-major * row8-major => col-major */ |
|
|
|
for (int r = 0; r < row; r++) { |
|
|
|
for (int c = 0; c < col; c++) { |
|
|
|
int r16div = r / C16NUM, r16mod = r % C16NUM; |
|
|
|
int c8div = c / C8NUM, c8mod = c % C8NUM; |
|
|
|
size_t ci = r * stride + c; |
|
|
|
float value = 0; |
|
|
|
for (int d = 0; d < deep; d++) { |
|
|
|
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; |
|
|
|
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; |
|
|
|
value = value + a[ai] * b[bi]; |
|
|
|
} |
|
|
|
if (bias != NULL) value += bias[c]; |
|
|
|
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); |
|
|
|
if (act_type != ActType_No) value = MSMAX(0.0f, value); |
|
|
|
dst[ci] = value; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
/* col16-major * row8-major => row16x8-major */ |
|
|
|
for (int r = 0; r < row_16; r++) { |
|
|
|
for (int c = 0; c < col_8; c++) { |
|
|
|
int r16div = r / C16NUM, r16mod = r % C16NUM; |
|
|
|
int c8div = c / C8NUM, c8mod = c % C8NUM; |
|
|
|
size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; |
|
|
|
float16_t value = 0; |
|
|
|
for (int d = 0; d < deep; d++) { |
|
|
|
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; |
|
|
|
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; |
|
|
|
value = value + a[ai] * b[bi]; |
|
|
|
} |
|
|
|
if (bias != NULL) value += bias[col]; |
|
|
|
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); |
|
|
|
if (act_type != ActType_No) value = MSMAX(0.0f, value); |
|
|
|
dst[ci] = value; |
|
|
|
} |
|
|
|
if (bias != NULL) value += bias[col]; |
|
|
|
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); |
|
|
|
if (act_type != ActType_No) value = MSMAX(0.0f, value); |
|
|
|
dst[ci] = value; |
|
|
|
} |
|
|
|
} |
|
|
|
return; |
|
|
|
@@ -43,12 +73,12 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl |
|
|
|
|
|
|
|
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, |
|
|
|
int depth, int row, int col, int stride, bool write_nhwc) { |
|
|
|
// MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); |
|
|
|
MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); |
|
|
|
MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); |
|
|
|
// MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { |
|
|
|
void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { |
|
|
|
size_t row16 = row / C16NUM * C16NUM; |
|
|
|
size_t col8 = col / C8NUM * C8NUM; |
|
|
|
float16_t *src_r = src_ptr; |
|
|
|
|