From 5148224516b6864945fe77b4e57717cfa9f37b2d Mon Sep 17 00:00:00 2001 From: luqiang guo <702572275@qq.com> Date: Tue, 13 Sep 2022 12:42:11 +0800 Subject: [PATCH] optimize softmax arm neon (#4171) --- src/layer/arm/softmax_arm.cpp | 73 +++++++++-------------------------- 1 file changed, 18 insertions(+), 55 deletions(-) diff --git a/src/layer/arm/softmax_arm.cpp b/src/layer/arm/softmax_arm.cpp index c00e3d441..77a0e6964 100644 --- a/src/layer/arm/softmax_arm.cpp +++ b/src/layer/arm/softmax_arm.cpp @@ -76,15 +76,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const _sum = vaddq_f32(_sum, vrev64q_f32(_sum)); _sum = vaddq_f32(_sum, vextq_f32(_sum, _sum, 2)); #endif - + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (int i = 0; i < w; i++) { float32x4_t _p = vld1q_f32(ptr + i * 4); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + i * 4, _p); } } @@ -152,11 +149,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr); float32x4_t _sum = vdupq_n_f32(sum[j]); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif vst1q_f32(ptr, _p); ptr += 4; } @@ -189,14 +182,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const _sum = vaddq_f32(_sum, _p); } + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (int j = 0; j < w; j++) { float32x4_t _p = vld1q_f32(ptr + j * 4); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + j * 4, _p); } } @@ -269,11 +260,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr); float32x4_t 
_sum = vdupq_n_f32(sum[i]); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif vst1q_f32(ptr, _p); ptr += 4; } @@ -356,11 +343,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr); float32x4_t _sum = vld1q_f32(sumptr); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif vst1q_f32(ptr, _p); ptr += 4; sumptr += 4; @@ -398,14 +381,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const _sum = vaddq_f32(_sum, _p); } + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (int j = 0; j < w; j++) { float32x4_t _p = vld1q_f32(ptr + j * 4); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + j * 4, _p); } @@ -480,14 +461,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int i = 0; #if __ARM_NEON float32x4_t _sum = vdupq_n_f32(sum); + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (; i + 3 < w; i += 4) { float32x4_t _p = vld1q_f32(ptr + i); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + i, _p); } #endif // __ARM_NEON @@ -587,11 +566,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr); float32x4_t _sum = vld1q_f32(psum); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif vst1q_f32(ptr, _p); ptr += 4; @@ -674,14 +649,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int j = 0; #if __ARM_NEON float32x4_t _sum = vdupq_n_f32(sum); + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = 
vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (; j + 3 < w; j += 4) { float32x4_t _p = vld1q_f32(ptr + j); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + j, _p); } #endif // __ARM_NEON @@ -790,11 +763,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr); float32x4_t _sum = vld1q_f32(sumptr); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif // __aarch64__ vst1q_f32(ptr, _p); ptr += 4; @@ -902,11 +871,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { float32x4_t _p = vld1q_f32(ptr + j); float32x4_t _sum = vld1q_f32(sumptr + j); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else _p = div_ps(_p, _sum); -#endif vst1q_f32(ptr + j, _p); } #endif // __ARM_NEON @@ -989,14 +954,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int j = 0; #if __ARM_NEON float32x4_t _sum = vdupq_n_f32(sum); + float32x4_t _reciprocal_sum = vrecpeq_f32(_sum); + _reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum); for (; j + 3 < w; j += 4) { float32x4_t _p = vld1q_f32(ptr + j); -#if __aarch64__ - _p = vdivq_f32(_p, _sum); -#else - _p = div_ps(_p, _sum); -#endif + _p = vmulq_f32(_p, _reciprocal_sum); vst1q_f32(ptr + j, _p); } #endif // __ARM_NEON