diff --git a/src/layer/arm/neon_mathfun.h b/src/layer/arm/neon_mathfun.h index 897d01305..daffae56e 100644 --- a/src/layer/arm/neon_mathfun.h +++ b/src/layer/arm/neon_mathfun.h @@ -112,35 +112,24 @@ static inline float32x4_t log_ps(float32x4_t x) float32x4_t z = vmulq_f32(x, x); float32x4_t y = vdupq_n_f32(c_cephes_log_p0); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7)); - y = vmulq_f32(y, x); - y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8)); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p1), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p2), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p3), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p4), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p5), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p6), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p7), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p8), y, x); y = vmulq_f32(y, x); y = vmulq_f32(y, z); - tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1)); - y = vaddq_f32(y, tmp); + y = vmlaq_f32(y, e, vdupq_n_f32(c_cephes_log_q1)); - tmp = vmulq_f32(z, vdupq_n_f32(0.5f)); - y = vsubq_f32(y, tmp); + y = vmlsq_f32(y, z, vdupq_n_f32(0.5f)); - tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); x = vaddq_f32(x, y); - x = vaddq_f32(x, tmp); + x = vmlaq_f32(x, e, vdupq_n_f32(c_cephes_log_q2)); x = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN return x; } @@ -185,29 +174,16 @@ static inline float32x4_t exp_ps(float32x4_t x) x = vsubq_f32(x, tmp); x = vsubq_f32(x, z); - static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5}; - float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0); - float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1); - float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2); - float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3); - float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4); - float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5); - - y = vmulq_f32(y, x); z = vmulq_f32(x, x); - y = vaddq_f32(y, c1); - y = vmulq_f32(y, x); - y = vaddq_f32(y, c2); - y = vmulq_f32(y, x); - y = vaddq_f32(y, c3); - y = vmulq_f32(y, x); - y = vaddq_f32(y, c4); - y = vmulq_f32(y, x); - y = vaddq_f32(y, c5); + float32x4_t y = vdupq_n_f32(c_cephes_exp_p0); + y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p1), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p2), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p3), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p4), y, x); + y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p5), y, x); - y = vmulq_f32(y, z); - y = vaddq_f32(y, x); + y = vmlaq_f32(x, y, z); y = vaddq_f32(y, one); /* build 2^n */ @@ -250,7 +226,7 @@ static inline float32x4_t exp_ps(float32x4_t x) static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) { // any x - float32x4_t xmm1, xmm2, xmm3, y; + float32x4_t y; uint32x4_t emm2; @@ -278,12 +254,9 @@ static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos /* The magic pass: "Extended precision modular arithmetic" * x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1); - xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2); - xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3); - x = vaddq_f32(x, xmm1); - x = vaddq_f32(x, xmm2); - x = vaddq_f32(x, xmm3); + x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP1)); + x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP2)); + x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP3)); sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4))); sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4)); @@ -293,20 +266,15 @@ static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos float32x4_t z = vmulq_f32(x, x); float32x4_t y1, y2; - y1 = vmulq_n_f32(z, c_coscof_p0); - y2 = vmulq_n_f32(z, c_sincof_p0); - y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1)); - y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1)); - y1 = vmulq_f32(y1, z); - y2 = vmulq_f32(y2, z); - y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2)); - y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2)); + y1 = vmlaq_f32(vdupq_n_f32(c_coscof_p1), z, vdupq_n_f32(c_coscof_p0)); + y2 = vmlaq_f32(vdupq_n_f32(c_sincof_p1), z, vdupq_n_f32(c_sincof_p0)); + y1 = vmlaq_f32(vdupq_n_f32(c_coscof_p2), y1, z); + y2 = vmlaq_f32(vdupq_n_f32(c_sincof_p2), y2, z); y1 = vmulq_f32(y1, z); y2 = vmulq_f32(y2, z); y1 = vmulq_f32(y1, z); - y2 = vmulq_f32(y2, x); - y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f))); - y2 = vaddq_f32(y2, x); + y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f)); + y2 = vmlaq_f32(x, y2, x); y1 = vaddq_f32(y1, vdupq_n_f32(1)); /* select the correct result from the two polynoms */ diff --git a/src/layer/arm/neon_mathfun_fp16s.h b/src/layer/arm/neon_mathfun_fp16s.h index 4dcb5814a..c739e8636 100644 --- a/src/layer/arm/neon_mathfun_fp16s.h +++ b/src/layer/arm/neon_mathfun_fp16s.h @@ -97,35 +97,24 @@ static inline float16x4_t log_ps(float16x4_t x) float16x4_t z = vmul_f16(x, x); float16x4_t y = vdup_n_f16(c_cephes_log_p0); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p1)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p2)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p3)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p4)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p5)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p6)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p7)); - y = vmul_f16(y, x); - y = vadd_f16(y, vdup_n_f16(c_cephes_log_p8)); + y = vfma_f16(vdup_n_f16(c_cephes_log_p1), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p2), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p3), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p4), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p5), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p6), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p7), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_log_p8), y, x); y = vmul_f16(y, x); y = vmul_f16(y, z); - tmp = vmul_f16(e, vdup_n_f16(c_cephes_log_q1)); - y = vadd_f16(y, tmp); + y = vfma_f16(y, e, vdup_n_f16(c_cephes_log_q1)); - tmp = vmul_f16(z, vdup_n_f16(0.5f)); - y = vsub_f16(y, tmp); + y = vfms_f16(y, z, vdup_n_f16(0.5f)); - tmp = vmul_f16(e, vdup_n_f16(c_cephes_log_q2)); x = vadd_f16(x, y); - x = vadd_f16(x, tmp); + x = vfma_f16(x, e, vdup_n_f16(c_cephes_log_q2)); x = vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(x), invalid_mask)); // negative arg will be NAN return x; } @@ -166,35 +155,24 @@ static inline float16x8_t log_ps(float16x8_t x) float16x8_t z = vmulq_f16(x, x); float16x8_t y = vdupq_n_f16(c_cephes_log_p0); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p1)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p2)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p3)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p4)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p5)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p6)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p7)); - y = vmulq_f16(y, x); - y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p8)); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p1), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p2), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p3), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p4), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p5), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p6), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p7), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p8), y, x); y = vmulq_f16(y, x); y = vmulq_f16(y, z); - tmp = vmulq_f16(e, vdupq_n_f16(c_cephes_log_q1)); - y = vaddq_f16(y, tmp); + y = vfmaq_f16(y, e, vdupq_n_f16(c_cephes_log_q1)); - tmp = vmulq_f16(z, vdupq_n_f16(0.5f)); - y = vsubq_f16(y, tmp); + y = vfmsq_f16(y, z, vdupq_n_f16(0.5f)); - tmp = vmulq_f16(e, vdupq_n_f16(c_cephes_log_q2)); x = vaddq_f16(x, y); - x = vaddq_f16(x, tmp); + x = vfmaq_f16(x, e, vdupq_n_f16(c_cephes_log_q2)); x = vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x), invalid_mask)); // negative arg will be NAN return x; } @@ -239,29 +217,16 @@ static inline float16x4_t exp_ps(float16x4_t x) x = vsub_f16(x, tmp); x = vsub_f16(x, z); - static const __fp16 cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5}; - float16x4_t y = vld1_dup_f16(cephes_exp_p + 0); - float16x4_t c1 = vld1_dup_f16(cephes_exp_p + 1); - float16x4_t c2 = vld1_dup_f16(cephes_exp_p + 2); - float16x4_t c3 = vld1_dup_f16(cephes_exp_p + 3); - float16x4_t c4 = vld1_dup_f16(cephes_exp_p + 4); - float16x4_t c5 = vld1_dup_f16(cephes_exp_p + 5); - - y = vmul_f16(y, x); z = vmul_f16(x, x); - y = vadd_f16(y, c1); - y = vmul_f16(y, x); - y = vadd_f16(y, c2); - y = vmul_f16(y, x); - y = vadd_f16(y, c3); - y = vmul_f16(y, x); - y = vadd_f16(y, c4); - y = vmul_f16(y, x); - y = vadd_f16(y, c5); + float16x4_t y = vdup_n_f16(c_cephes_exp_p0); + y = vfma_f16(vdup_n_f16(c_cephes_exp_p1), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_exp_p2), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_exp_p3), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_exp_p4), y, x); + y = vfma_f16(vdup_n_f16(c_cephes_exp_p5), y, x); - y = vmul_f16(y, z); - y = vadd_f16(y, x); + y = vfma_f16(x, y, z); y = vadd_f16(y, one); /* build 2^n */ @@ -300,29 +265,16 @@ static inline float16x8_t exp_ps(float16x8_t x) x = vsubq_f16(x, tmp); x = vsubq_f16(x, z); - static const __fp16 cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5}; - float16x8_t y = vld1q_dup_f16(cephes_exp_p + 0); - float16x8_t c1 = vld1q_dup_f16(cephes_exp_p + 1); - float16x8_t c2 = vld1q_dup_f16(cephes_exp_p + 2); - float16x8_t c3 = vld1q_dup_f16(cephes_exp_p + 3); - float16x8_t c4 = vld1q_dup_f16(cephes_exp_p + 4); - float16x8_t c5 = vld1q_dup_f16(cephes_exp_p + 5); - - y = vmulq_f16(y, x); z = vmulq_f16(x, x); - y = vaddq_f16(y, c1); - y = vmulq_f16(y, x); - y = vaddq_f16(y, c2); - y = vmulq_f16(y, x); - y = vaddq_f16(y, c3); - y = vmulq_f16(y, x); - y = vaddq_f16(y, c4); - y = vmulq_f16(y, x); - y = vaddq_f16(y, c5); + float16x8_t y = vdupq_n_f16(c_cephes_exp_p0); + y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p1), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p2), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p3), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p4), y, x); + y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p5), y, x); - y = vmulq_f16(y, z); - y = vaddq_f16(y, x); + y = vfmaq_f16(x, y, z); y = vaddq_f16(y, one); /* build 2^n */ @@ -365,7 +317,7 @@ static inline float16x8_t exp_ps(float16x8_t x) static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos) { // any x - float16x4_t xmm1, xmm2, xmm3, y; + float16x4_t y; uint16x4_t emm2; @@ -393,12 +345,9 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos /* The magic pass: "Extended precision modular arithmetic" * x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = vmul_n_f16(y, c_minus_cephes_DP1); - xmm2 = vmul_n_f16(y, c_minus_cephes_DP2); - xmm3 = vmul_n_f16(y, c_minus_cephes_DP3); - x = vadd_f16(x, xmm1); - x = vadd_f16(x, xmm2); - x = vadd_f16(x, xmm3); + x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP1)); + x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP2)); + x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP3)); sign_mask_sin = veor_u16(sign_mask_sin, vtst_u16(emm2, vdup_n_u16(4))); sign_mask_cos = vtst_u16(vsub_u16(emm2, vdup_n_u16(2)), vdup_n_u16(4)); @@ -408,20 +357,15 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos float16x4_t z = vmul_f16(x, x); float16x4_t y1, y2; - y1 = vmul_n_f16(z, c_coscof_p0); - y2 = vmul_n_f16(z, c_sincof_p0); - y1 = vadd_f16(y1, vdup_n_f16(c_coscof_p1)); - y2 = vadd_f16(y2, vdup_n_f16(c_sincof_p1)); - y1 = vmul_f16(y1, z); - y2 = vmul_f16(y2, z); - y1 = vadd_f16(y1, vdup_n_f16(c_coscof_p2)); - y2 = vadd_f16(y2, vdup_n_f16(c_sincof_p2)); + y1 = vfma_f16(vdup_n_f16(c_coscof_p1), z, vdup_n_f16(c_coscof_p0)); + y2 = vfma_f16(vdup_n_f16(c_sincof_p1), z, vdup_n_f16(c_sincof_p0)); + y1 = vfma_f16(vdup_n_f16(c_coscof_p2), y1, z); + y2 = vfma_f16(vdup_n_f16(c_sincof_p2), y2, z); y1 = vmul_f16(y1, z); y2 = vmul_f16(y2, z); y1 = vmul_f16(y1, z); - y2 = vmul_f16(y2, x); - y1 = vsub_f16(y1, vmul_f16(z, vdup_n_f16(0.5f))); - y2 = vadd_f16(y2, x); + y1 = vfms_f16(y1, z, vdup_n_f16(0.5f)); + y2 = vfma_f16(x, y2, x); y1 = vadd_f16(y1, vdup_n_f16(1)); /* select the correct result from the two polynoms */ @@ -434,7 +378,7 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos) { // any x - float16x8_t xmm1, xmm2, xmm3, y; + float16x8_t y; uint16x8_t emm2; @@ -462,12 +406,9 @@ static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos /* The magic pass: "Extended precision modular arithmetic" * x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = vmulq_n_f16(y, c_minus_cephes_DP1); - xmm2 = vmulq_n_f16(y, c_minus_cephes_DP2); - xmm3 = vmulq_n_f16(y, c_minus_cephes_DP3); - x = vaddq_f16(x, xmm1); - x = vaddq_f16(x, xmm2); - x = vaddq_f16(x, xmm3); + x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP1)); + x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP2)); + x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP3)); sign_mask_sin = veorq_u16(sign_mask_sin, vtstq_u16(emm2, vdupq_n_u16(4))); sign_mask_cos = vtstq_u16(vsubq_u16(emm2, vdupq_n_u16(2)), vdupq_n_u16(4)); @@ -477,20 +418,15 @@ static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos float16x8_t z = vmulq_f16(x, x); float16x8_t y1, y2; - y1 = vmulq_n_f16(z, c_coscof_p0); - y2 = vmulq_n_f16(z, c_sincof_p0); - y1 = vaddq_f16(y1, vdupq_n_f16(c_coscof_p1)); - y2 = vaddq_f16(y2, vdupq_n_f16(c_sincof_p1)); - y1 = vmulq_f16(y1, z); - y2 = vmulq_f16(y2, z); - y1 = vaddq_f16(y1, vdupq_n_f16(c_coscof_p2)); - y2 = vaddq_f16(y2, vdupq_n_f16(c_sincof_p2)); + y1 = vfmaq_f16(vdupq_n_f16(c_coscof_p1), z, vdupq_n_f16(c_coscof_p0)); + y2 = vfmaq_f16(vdupq_n_f16(c_sincof_p1), z, vdupq_n_f16(c_sincof_p0)); + y1 = vfmaq_f16(vdupq_n_f16(c_coscof_p2), y1, z); + y2 = vfmaq_f16(vdupq_n_f16(c_sincof_p2), y2, z); y1 = vmulq_f16(y1, z); y2 = vmulq_f16(y2, z); y1 = vmulq_f16(y1, z); - y2 = vmulq_f16(y2, x); - y1 = vsubq_f16(y1, vmulq_f16(z, vdupq_n_f16(0.5f))); - y2 = vaddq_f16(y2, x); + y1 = vfmsq_f16(y1, z, vdupq_n_f16(0.5f)); + y2 = vfmaq_f16(x, y2, x); y1 = vaddq_f16(y1, vdupq_n_f16(1)); /* select the correct result from the two polynoms */