Browse Source

optimization for fp16 matmul kernel

tags/v1.1.0
lixian 5 years ago
parent
commit
869bffe976
3 changed files with 1319 additions and 1 deletions
  1. +1311
    -0
      mindspore/lite/nnacl/assembly/fp16/MatmulFp16Opt.S
  2. +5
    -1
      mindspore/lite/nnacl/fp16/matmul_fp16.c
  3. +3
    -0
      mindspore/lite/nnacl/fp16/matmul_fp16.h

+ 1311
- 0
mindspore/lite/nnacl/assembly/fp16/MatmulFp16Opt.S
File diff suppressed because it is too large
View File


+ 5
- 1
mindspore/lite/nnacl/fp16/matmul_fp16.c View File

@@ -87,7 +87,11 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl


 void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                 int depth, int row, int col, int stride, bool write_nhwc) {
-  MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
+  if (!write_nhwc) {
+    MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
+  } else {
+    MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, 1);
+  }
   return;
 }




+ 3
- 0
mindspore/lite/nnacl/fp16/matmul_fp16.h View File

@@ -39,6 +39,9 @@ void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t r
 void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                       size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
 
+void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
+                         size_t depth, size_t row, size_t col, size_t stride, int write_nhwc);
+
 void RowMajor2Col16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);
 
 void RowMajor2Row16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);


Loading…
Cancel
Save