From 7271ac8a80accd282c42166a5947a7a0a289b8ef Mon Sep 17 00:00:00 2001 From: lixian Date: Wed, 21 Oct 2020 15:01:13 +0800 Subject: [PATCH] add fp32 deconv merge assembly --- mindspore/lite/nnacl/fp32/deconv_winograd.c | 151 +++++++++++++++++++- 1 file changed, 146 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/deconv_winograd.c b/mindspore/lite/nnacl/fp32/deconv_winograd.c index 7228c84215..08d6616bc3 100644 --- a/mindspore/lite/nnacl/fp32/deconv_winograd.c +++ b/mindspore/lite/nnacl/fp32/deconv_winograd.c @@ -159,12 +159,153 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t #endif void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { - for (int i = 0; i < count; ++i) { - const float *s = src + i * src_stride; - float *d = dst + i * dst_stride; - for (int j = 0; j < 4; ++j) { - d[j] += s[j]; + const float *src_ptr = src; + float *dst_ptr = dst; + size_t cuont8 = count / C8NUM * C8NUM; + int i = 0; + for (; i < cuont8; i += 8) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "ld1 {v4.4s}, [x7], %[src_step]\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v4.4s}, [x7], %[src_step]\n" + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif ENABLE_ARM32 + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + asm volatile( + "mov r7, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r7], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r7], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "r7", "r8", "r10", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); +#else + for (int j = 0; j < 8; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < 4; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += 8 * src_stride; + dst_ptr += 8 * dst_stride; + } + for (; i < count; i++) { +#ifdef ENABLE_ARM + float32x4_t src_data = vld1q_f32(src_ptr); + float32x4_t dst_data = vld1q_f32(dst_ptr); + dst_data = vaddq_f32(src_data, dst_data); + vst1q_f32(dst_ptr, dst_data); +#else + for (int j = 0; j < 4; j++) { + dst_ptr[j] += src_ptr[j]; } +#endif + src_ptr += src_stride; + dst_ptr += dst_stride; } return; }