Browse Source

!5182 optimization for winograd matmul

Merge pull request !5182 from lixian/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
bd12a37d4f
4 changed files with 154 additions and 6 deletions
  1. +144
    -0
      mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S
  2. +2
    -3
      mindspore/lite/nnacl/fp32/conv.c
  3. +7
    -3
      mindspore/lite/nnacl/fp32/matmul.c
  4. +1
    -0
      mindspore/lite/nnacl/fp32/matmul.h

+ 144
- 0
mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S View File

@@ -0,0 +1,144 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64OptRemain
#ifndef __APPLE__
.type MatmulFloatNeon64OptRemain, %function
#endif

// void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c,
//                                 int depth, int row, int col, size_t stride)
//
// Remainder matmul microkernel for the winograd path (row <= 8, one final
// row tile).  lhs is packed 12 rows wide per depth step (C12NUM layout);
// only the first 8 (or 4) packed rows are consumed here.  rhs is packed
// 8 columns wide per depth step.  Results are stored as 8-float groups,
// 'stride' floats apart per row.
//
// AAPCS64 inputs:
// x0: a      packed lhs (12 floats per depth step)
// x1: b      packed rhs (8 floats per depth step)
// x2: c      dst
// x3: depth
// x4: row    remainder row count of the tile
// x5: col
// x6: stride (in float elements, not bytes)
//
// Internal register roles:
//   x8  = col * stride * sizeof(float)  -> dst byte step between row groups
//   x9  = depth * 32                    -> rhs byte stride per 8-col block
//   x11 = 8 * stride * sizeof(float)    -> dst byte step per 8-col block
//   x10/x12/x13/x16/x18 = loop-local row counter / lhs ptr / depth counter /
//                         rhs ptr / dst ptr
//
// NOTE(review): x18 is the platform-reserved register on Apple/Windows
// AArch64 (AAPCS64); this kernel clobbers it as scratch.  Confirm the
// target platforms before reuse — here the file is Linux/Android-oriented.
// only for winograd
MatmulFloatNeon64OptRemain:
mov x18, #32 // sizeof(float) * 8
mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x18, #4 // sizeof(float)
mul x8, x5, x6 // col * stride
mov x11, #8
mul x11, x11, x6 // 8 * stride
mul x8, x8, x18 // dst row-group step in bytes
mul x11, x11, x18 // dst col-block step in bytes

cmp x4, #4
ble LoopH4 // row <= 4: use the 4-row kernel

LoopH8: // per 8-wide rhs column block (8-row kernel)
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr

LoopW8: // per 8-row tile of lhs
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
// v16..v31: 8 rows x 8 cols of accumulators, zero-initialized
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr

LoopD8: // accumulate one depth step: c[8x8] += a[8x1] * b[1x8]
// lhs is packed 12-wide: load 12 floats (48 B) but only v0/v1 (rows 0-7)
// are used; v2 (rows 8-11) is dead in this remainder kernel.
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
fmla v24.4s, v3.4s, v1.s[0]
fmla v26.4s, v3.4s, v1.s[1]
fmla v28.4s, v3.4s, v1.s[2]
fmla v30.4s, v3.4s, v1.s[3]
fmla v25.4s, v4.4s, v1.s[0]
fmla v27.4s, v4.4s, v1.s[1]
fmla v29.4s, v4.4s, v1.s[2]
fmla v31.4s, v4.4s, v1.s[3]

// NOTE(review): 32-bit subs here vs 64-bit 'subs x13' in LoopD4 — both
// work for int depth, but the inconsistency is worth unifying.
subs w13, w13, #1
bgt LoopD8

// store 8 rows x 8 cols; x8 advances one row group per store
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8

// NOTE(review): always stores full 8 rows even when 4 < row < 8 —
// presumably dst is padded/rounded up by the caller; confirm.
subs x10, x10, #8 // lhs row - 8
bgt LoopW8

subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11 // dst ptr + col-block step
bgt LoopH8

ret

LoopH4: // per 8-wide rhs column block (4-row kernel)
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr

LoopW4:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
// v16..v23: 4 rows x 8 cols of accumulators, zero-initialized
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr

LoopD4: // accumulate one depth step: c[4x8] += a[4x1] * b[1x8]
// 12-wide packed lhs load; only v0 (rows 0-3) is used here.
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]

subs x13, x13, #1
bgt LoopD4

// store 4 rows x 8 cols
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8

subs x10, x10, #4 // lhs row - 4
bgt LoopW4

subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11 // dst ptr + col-block step
bgt LoopH4
ret
#endif

+ 2
- 3
mindspore/lite/nnacl/fp32/conv.c View File

@@ -303,7 +303,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
C12NUM, oc8 * C8NUM, input_unit_square, 2);
cal_num, oc8 * C8NUM, input_unit_square, 2);
}

// step 4 : output transform
@@ -489,9 +489,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2);
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}

Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}


+ 7
- 3
mindspore/lite/nnacl/fp32/matmul.c View File

@@ -386,7 +386,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * row;
size_t ai = src_r_offset + d * C12NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
@@ -403,8 +403,12 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
if (out_type == 2 && row <= 8) {
MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
} else {
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
}
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif


+ 1
- 0
mindspore/lite/nnacl/fp32/matmul.h View File

@@ -39,6 +39,7 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
#endif
#ifdef __cplusplus
}


Loading…
Cancel
Save