Browse Source

optmize softmax arm neon (#4171)

tags/20221128
luqiang guo GitHub 3 years ago
parent
commit
5148224516
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 55 deletions
  1. +18
    -55
      src/layer/arm/softmax_arm.cpp

+ 18
- 55
src/layer/arm/softmax_arm.cpp View File

@@ -76,15 +76,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
_sum = vaddq_f32(_sum, vrev64q_f32(_sum));
_sum = vaddq_f32(_sum, vextq_f32(_sum, _sum, 2));
#endif

float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (int i = 0; i < w; i++)
{
float32x4_t _p = vld1q_f32(ptr + i * 4);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + i * 4, _p);
}
}
@@ -152,11 +149,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _sum = vdupq_n_f32(sum[j]);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
vst1q_f32(ptr, _p);
ptr += 4;
}
@@ -189,14 +182,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
_sum = vaddq_f32(_sum, _p);
}

float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (int j = 0; j < w; j++)
{
float32x4_t _p = vld1q_f32(ptr + j * 4);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + j * 4, _p);
}
}
@@ -269,11 +260,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _sum = vdupq_n_f32(sum[i]);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
vst1q_f32(ptr, _p);
ptr += 4;
}
@@ -356,11 +343,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _sum = vld1q_f32(sumptr);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
vst1q_f32(ptr, _p);
ptr += 4;
sumptr += 4;
@@ -398,14 +381,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
_sum = vaddq_f32(_sum, _p);
}

float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (int j = 0; j < w; j++)
{
float32x4_t _p = vld1q_f32(ptr + j * 4);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + j * 4, _p);
}

@@ -480,14 +461,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
int i = 0;
#if __ARM_NEON
float32x4_t _sum = vdupq_n_f32(sum);
float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (; i + 3 < w; i += 4)
{
float32x4_t _p = vld1q_f32(ptr + i);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + i, _p);
}
#endif // __ARM_NEON
@@ -587,11 +566,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _sum = vld1q_f32(psum);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
vst1q_f32(ptr, _p);

ptr += 4;
@@ -674,14 +649,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
int j = 0;
#if __ARM_NEON
float32x4_t _sum = vdupq_n_f32(sum);
float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (; j + 3 < w; j += 4)
{
float32x4_t _p = vld1q_f32(ptr + j);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + j, _p);
}
#endif // __ARM_NEON
@@ -790,11 +763,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _sum = vld1q_f32(sumptr);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif // __aarch64__
vst1q_f32(ptr, _p);

ptr += 4;
@@ -902,11 +871,7 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
float32x4_t _p = vld1q_f32(ptr + j);
float32x4_t _sum = vld1q_f32(sumptr + j);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
vst1q_f32(ptr + j, _p);
}
#endif // __ARM_NEON
@@ -989,14 +954,12 @@ int Softmax_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
int j = 0;
#if __ARM_NEON
float32x4_t _sum = vdupq_n_f32(sum);
float32x4_t _reciprocal_sum = vrecpeq_f32(_sum);
_reciprocal_sum = vmulq_f32(vrecpsq_f32(_sum, _reciprocal_sum), _reciprocal_sum);
for (; j + 3 < w; j += 4)
{
float32x4_t _p = vld1q_f32(ptr + j);
#if __aarch64__
_p = vdivq_f32(_p, _sum);
#else
_p = div_ps(_p, _sum);
#endif
_p = vmulq_f32(_p, _reciprocal_sum);
vst1q_f32(ptr + j, _p);
}
#endif // __ARM_NEON


Loading…
Cancel
Save