Browse Source

optimize arm neon exp log sincos fma (#3077)

tags/20210720
nihui GitHub 4 years ago
parent
commit
06a7086daa
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 180 deletions
  1. +28
    -60
      src/layer/arm/neon_mathfun.h
  2. +56
    -120
      src/layer/arm/neon_mathfun_fp16s.h

+ 28
- 60
src/layer/arm/neon_mathfun.h View File

@@ -112,35 +112,24 @@ static inline float32x4_t log_ps(float32x4_t x)
float32x4_t z = vmulq_f32(x, x);

float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p1), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p2), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p3), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p4), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p5), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p6), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p7), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_log_p8), y, x);
y = vmulq_f32(y, x);

y = vmulq_f32(y, z);

tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
y = vaddq_f32(y, tmp);
y = vmlaq_f32(y, e, vdupq_n_f32(c_cephes_log_q1));

tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
y = vsubq_f32(y, tmp);
y = vmlsq_f32(y, z, vdupq_n_f32(0.5f));

tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
x = vaddq_f32(x, y);
x = vaddq_f32(x, tmp);
x = vmlaq_f32(x, e, vdupq_n_f32(c_cephes_log_q2));
x = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN
return x;
}
@@ -185,29 +174,16 @@ static inline float32x4_t exp_ps(float32x4_t x)
x = vsubq_f32(x, tmp);
x = vsubq_f32(x, z);

static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5};
float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);

y = vmulq_f32(y, x);
z = vmulq_f32(x, x);

y = vaddq_f32(y, c1);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c2);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c3);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c4);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c5);
float32x4_t y = vdupq_n_f32(c_cephes_exp_p0);
y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p1), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p2), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p3), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p4), y, x);
y = vmlaq_f32(vdupq_n_f32(c_cephes_exp_p5), y, x);

y = vmulq_f32(y, z);
y = vaddq_f32(y, x);
y = vmlaq_f32(x, y, z);
y = vaddq_f32(y, one);

/* build 2^n */
@@ -250,7 +226,7 @@ static inline float32x4_t exp_ps(float32x4_t x)
static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos)
{
// any x
float32x4_t xmm1, xmm2, xmm3, y;
float32x4_t y;

uint32x4_t emm2;

@@ -278,12 +254,9 @@ static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos

/* The magic pass: "Extended precision modular arithmetic"
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
x = vaddq_f32(x, xmm1);
x = vaddq_f32(x, xmm2);
x = vaddq_f32(x, xmm3);
x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP1));
x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP2));
x = vmlaq_f32(x, y, vdupq_n_f32(c_minus_cephes_DP3));

sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
@@ -293,20 +266,15 @@ static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos
float32x4_t z = vmulq_f32(x, x);
float32x4_t y1, y2;

y1 = vmulq_n_f32(z, c_coscof_p0);
y2 = vmulq_n_f32(z, c_sincof_p0);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
y1 = vmlaq_f32(vdupq_n_f32(c_coscof_p1), z, vdupq_n_f32(c_coscof_p0));
y2 = vmlaq_f32(vdupq_n_f32(c_sincof_p1), z, vdupq_n_f32(c_sincof_p0));
y1 = vmlaq_f32(vdupq_n_f32(c_coscof_p2), y1, z);
y2 = vmlaq_f32(vdupq_n_f32(c_sincof_p2), y2, z);
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, x);
y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
y2 = vaddq_f32(y2, x);
y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f));
y2 = vmlaq_f32(x, y2, x);
y1 = vaddq_f32(y1, vdupq_n_f32(1));

/* select the correct result from the two polynoms */


+ 56
- 120
src/layer/arm/neon_mathfun_fp16s.h View File

@@ -97,35 +97,24 @@ static inline float16x4_t log_ps(float16x4_t x)
float16x4_t z = vmul_f16(x, x);

float16x4_t y = vdup_n_f16(c_cephes_log_p0);
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p1));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p2));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p3));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p4));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p5));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p6));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p7));
y = vmul_f16(y, x);
y = vadd_f16(y, vdup_n_f16(c_cephes_log_p8));
y = vfma_f16(vdup_n_f16(c_cephes_log_p1), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p2), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p3), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p4), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p5), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p6), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p7), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_log_p8), y, x);
y = vmul_f16(y, x);

y = vmul_f16(y, z);

tmp = vmul_f16(e, vdup_n_f16(c_cephes_log_q1));
y = vadd_f16(y, tmp);
y = vfma_f16(y, e, vdup_n_f16(c_cephes_log_q1));

tmp = vmul_f16(z, vdup_n_f16(0.5f));
y = vsub_f16(y, tmp);
y = vfms_f16(y, z, vdup_n_f16(0.5f));

tmp = vmul_f16(e, vdup_n_f16(c_cephes_log_q2));
x = vadd_f16(x, y);
x = vadd_f16(x, tmp);
x = vfma_f16(x, e, vdup_n_f16(c_cephes_log_q2));
x = vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(x), invalid_mask)); // negative arg will be NAN
return x;
}
@@ -166,35 +155,24 @@ static inline float16x8_t log_ps(float16x8_t x)
float16x8_t z = vmulq_f16(x, x);

float16x8_t y = vdupq_n_f16(c_cephes_log_p0);
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p1));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p2));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p3));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p4));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p5));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p6));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p7));
y = vmulq_f16(y, x);
y = vaddq_f16(y, vdupq_n_f16(c_cephes_log_p8));
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p1), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p2), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p3), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p4), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p5), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p6), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p7), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_log_p8), y, x);
y = vmulq_f16(y, x);

y = vmulq_f16(y, z);

tmp = vmulq_f16(e, vdupq_n_f16(c_cephes_log_q1));
y = vaddq_f16(y, tmp);
y = vfmaq_f16(y, e, vdupq_n_f16(c_cephes_log_q1));

tmp = vmulq_f16(z, vdupq_n_f16(0.5f));
y = vsubq_f16(y, tmp);
y = vfmsq_f16(y, z, vdupq_n_f16(0.5f));

tmp = vmulq_f16(e, vdupq_n_f16(c_cephes_log_q2));
x = vaddq_f16(x, y);
x = vaddq_f16(x, tmp);
x = vfmaq_f16(x, e, vdupq_n_f16(c_cephes_log_q2));
x = vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x), invalid_mask)); // negative arg will be NAN
return x;
}
@@ -239,29 +217,16 @@ static inline float16x4_t exp_ps(float16x4_t x)
x = vsub_f16(x, tmp);
x = vsub_f16(x, z);

static const __fp16 cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5};
float16x4_t y = vld1_dup_f16(cephes_exp_p + 0);
float16x4_t c1 = vld1_dup_f16(cephes_exp_p + 1);
float16x4_t c2 = vld1_dup_f16(cephes_exp_p + 2);
float16x4_t c3 = vld1_dup_f16(cephes_exp_p + 3);
float16x4_t c4 = vld1_dup_f16(cephes_exp_p + 4);
float16x4_t c5 = vld1_dup_f16(cephes_exp_p + 5);

y = vmul_f16(y, x);
z = vmul_f16(x, x);

y = vadd_f16(y, c1);
y = vmul_f16(y, x);
y = vadd_f16(y, c2);
y = vmul_f16(y, x);
y = vadd_f16(y, c3);
y = vmul_f16(y, x);
y = vadd_f16(y, c4);
y = vmul_f16(y, x);
y = vadd_f16(y, c5);
float16x4_t y = vdup_n_f16(c_cephes_exp_p0);
y = vfma_f16(vdup_n_f16(c_cephes_exp_p1), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_exp_p2), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_exp_p3), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_exp_p4), y, x);
y = vfma_f16(vdup_n_f16(c_cephes_exp_p5), y, x);

y = vmul_f16(y, z);
y = vadd_f16(y, x);
y = vfma_f16(x, y, z);
y = vadd_f16(y, one);

/* build 2^n */
@@ -300,29 +265,16 @@ static inline float16x8_t exp_ps(float16x8_t x)
x = vsubq_f16(x, tmp);
x = vsubq_f16(x, z);

static const __fp16 cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5};
float16x8_t y = vld1q_dup_f16(cephes_exp_p + 0);
float16x8_t c1 = vld1q_dup_f16(cephes_exp_p + 1);
float16x8_t c2 = vld1q_dup_f16(cephes_exp_p + 2);
float16x8_t c3 = vld1q_dup_f16(cephes_exp_p + 3);
float16x8_t c4 = vld1q_dup_f16(cephes_exp_p + 4);
float16x8_t c5 = vld1q_dup_f16(cephes_exp_p + 5);

y = vmulq_f16(y, x);
z = vmulq_f16(x, x);

y = vaddq_f16(y, c1);
y = vmulq_f16(y, x);
y = vaddq_f16(y, c2);
y = vmulq_f16(y, x);
y = vaddq_f16(y, c3);
y = vmulq_f16(y, x);
y = vaddq_f16(y, c4);
y = vmulq_f16(y, x);
y = vaddq_f16(y, c5);
float16x8_t y = vdupq_n_f16(c_cephes_exp_p0);
y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p1), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p2), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p3), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p4), y, x);
y = vfmaq_f16(vdupq_n_f16(c_cephes_exp_p5), y, x);

y = vmulq_f16(y, z);
y = vaddq_f16(y, x);
y = vfmaq_f16(x, y, z);
y = vaddq_f16(y, one);

/* build 2^n */
@@ -365,7 +317,7 @@ static inline float16x8_t exp_ps(float16x8_t x)
static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos)
{
// any x
float16x4_t xmm1, xmm2, xmm3, y;
float16x4_t y;

uint16x4_t emm2;

@@ -393,12 +345,9 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos

/* The magic pass: "Extended precision modular arithmetic"
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1 = vmul_n_f16(y, c_minus_cephes_DP1);
xmm2 = vmul_n_f16(y, c_minus_cephes_DP2);
xmm3 = vmul_n_f16(y, c_minus_cephes_DP3);
x = vadd_f16(x, xmm1);
x = vadd_f16(x, xmm2);
x = vadd_f16(x, xmm3);
x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP1));
x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP2));
x = vfma_f16(x, y, vdup_n_f16(c_minus_cephes_DP3));

sign_mask_sin = veor_u16(sign_mask_sin, vtst_u16(emm2, vdup_n_u16(4)));
sign_mask_cos = vtst_u16(vsub_u16(emm2, vdup_n_u16(2)), vdup_n_u16(4));
@@ -408,20 +357,15 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos
float16x4_t z = vmul_f16(x, x);
float16x4_t y1, y2;

y1 = vmul_n_f16(z, c_coscof_p0);
y2 = vmul_n_f16(z, c_sincof_p0);
y1 = vadd_f16(y1, vdup_n_f16(c_coscof_p1));
y2 = vadd_f16(y2, vdup_n_f16(c_sincof_p1));
y1 = vmul_f16(y1, z);
y2 = vmul_f16(y2, z);
y1 = vadd_f16(y1, vdup_n_f16(c_coscof_p2));
y2 = vadd_f16(y2, vdup_n_f16(c_sincof_p2));
y1 = vfma_f16(vdup_n_f16(c_coscof_p1), z, vdup_n_f16(c_coscof_p0));
y2 = vfma_f16(vdup_n_f16(c_sincof_p1), z, vdup_n_f16(c_sincof_p0));
y1 = vfma_f16(vdup_n_f16(c_coscof_p2), y1, z);
y2 = vfma_f16(vdup_n_f16(c_sincof_p2), y2, z);
y1 = vmul_f16(y1, z);
y2 = vmul_f16(y2, z);
y1 = vmul_f16(y1, z);
y2 = vmul_f16(y2, x);
y1 = vsub_f16(y1, vmul_f16(z, vdup_n_f16(0.5f)));
y2 = vadd_f16(y2, x);
y1 = vfms_f16(y1, z, vdup_n_f16(0.5f));
y2 = vfma_f16(x, y2, x);
y1 = vadd_f16(y1, vdup_n_f16(1));

/* select the correct result from the two polynoms */
@@ -434,7 +378,7 @@ static inline void sincos_ps(float16x4_t x, float16x4_t* ysin, float16x4_t* ycos
static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos)
{
// any x
float16x8_t xmm1, xmm2, xmm3, y;
float16x8_t y;

uint16x8_t emm2;

@@ -462,12 +406,9 @@ static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos

/* The magic pass: "Extended precision modular arithmetic"
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1 = vmulq_n_f16(y, c_minus_cephes_DP1);
xmm2 = vmulq_n_f16(y, c_minus_cephes_DP2);
xmm3 = vmulq_n_f16(y, c_minus_cephes_DP3);
x = vaddq_f16(x, xmm1);
x = vaddq_f16(x, xmm2);
x = vaddq_f16(x, xmm3);
x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP1));
x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP2));
x = vfmaq_f16(x, y, vdupq_n_f16(c_minus_cephes_DP3));

sign_mask_sin = veorq_u16(sign_mask_sin, vtstq_u16(emm2, vdupq_n_u16(4)));
sign_mask_cos = vtstq_u16(vsubq_u16(emm2, vdupq_n_u16(2)), vdupq_n_u16(4));
@@ -477,20 +418,15 @@ static inline void sincos_ps(float16x8_t x, float16x8_t* ysin, float16x8_t* ycos
float16x8_t z = vmulq_f16(x, x);
float16x8_t y1, y2;

y1 = vmulq_n_f16(z, c_coscof_p0);
y2 = vmulq_n_f16(z, c_sincof_p0);
y1 = vaddq_f16(y1, vdupq_n_f16(c_coscof_p1));
y2 = vaddq_f16(y2, vdupq_n_f16(c_sincof_p1));
y1 = vmulq_f16(y1, z);
y2 = vmulq_f16(y2, z);
y1 = vaddq_f16(y1, vdupq_n_f16(c_coscof_p2));
y2 = vaddq_f16(y2, vdupq_n_f16(c_sincof_p2));
y1 = vfmaq_f16(vdupq_n_f16(c_coscof_p1), z, vdupq_n_f16(c_coscof_p0));
y2 = vfmaq_f16(vdupq_n_f16(c_sincof_p1), z, vdupq_n_f16(c_sincof_p0));
y1 = vfmaq_f16(vdupq_n_f16(c_coscof_p2), y1, z);
y2 = vfmaq_f16(vdupq_n_f16(c_sincof_p2), y2, z);
y1 = vmulq_f16(y1, z);
y2 = vmulq_f16(y2, z);
y1 = vmulq_f16(y1, z);
y2 = vmulq_f16(y2, x);
y1 = vsubq_f16(y1, vmulq_f16(z, vdupq_n_f16(0.5f)));
y2 = vaddq_f16(y2, x);
y1 = vfmsq_f16(y1, z, vdupq_n_f16(0.5f));
y2 = vfmaq_f16(x, y2, x);
y1 = vaddq_f16(y1, vdupq_n_f16(1));

/* select the correct result from the two polynoms */


Loading…
Cancel
Save