Browse Source

Add AVX512F implementation of atan2 in x86 targets. (#4641)

tags/20230517
Kenji Mouri GitHub 3 years ago
parent
commit
f2a5a81a5d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 65 additions and 9 deletions
  1. +65
    -9
      src/layer/x86/avx512_mathfun.h

+ 65
- 9
src/layer/x86/avx512_mathfun.h View File

@@ -779,16 +779,72 @@ static NCNN_FORCEINLINE __m512 atan512_ps(__m512 x)
negative_mask);
}

// Vectorized atan2(y, x) over 16 packed single-precision lanes.
// Reference: https://mazzo.li/posts/vectorized-atan2.html
//
// Lanes where both x and y are non-zero use atan(y / x) plus a quadrant
// correction of +PI or -PI when x is negative. Lanes where either input is
// zero take a special path: +/-PI/2 (carrying the sign bit of y) when y is
// non-zero, otherwise PI or 0 depending on the sign bit of x.
// NOTE: the zero tests are ordered compares (_CMP_NEQ_OQ), so a lane holding
// NaN in x or y is treated like a zero and also takes the special path.
//
// Using NCNN_FORCEINLINE on this function breaks the MSVC 2017 x86 CI build,
// so the attribute is dropped for MSVC older than _MSC_VER 1920 only (per the
// original note, MSVC 2017 still inlines the function without it). The
// defined() check matters: an undefined _MSC_VER evaluates to 0 inside #if,
// which would otherwise strip the attribute on every non-MSVC compiler too.
#if defined(_MSC_VER) && _MSC_VER < 1920
static __m512 atan2512_ps(__m512 y, __m512 x)
#else
static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 y, __m512 x)
#endif
{
    const __m512 magic_zero = _mm512_set1_ps(0.0f);
    const __m512 magic_negative_zero = _mm512_set1_ps(-0.0f); // sign bit only
    const __m512 magic_pi = _mm512_set1_ps(3.1415927f);
    const __m512 magic_half_pi = _mm512_set1_ps(1.5707964f);

    // not_equal_zero_x = (x != 0.0f);
    __mmask16 not_equal_zero_x = _mm512_cmp_ps_mask(
                                     x,
                                     magic_zero,
                                     _CMP_NEQ_OQ);

    // not_equal_zero_y = (y != 0.0f);
    __mmask16 not_equal_zero_y = _mm512_cmp_ps_mask(
                                     y,
                                     magic_zero,
                                     _CMP_NEQ_OQ);

    // normal_mode = ((x != 0.0f) & (y != 0.0f));
    __mmask16 normal_mode = (not_equal_zero_x & not_equal_zero_y);

    // negative_mask_x = (magic_negative_zero & x);  // bitwise: sign bit of x
    __m512 negative_mask_x = _mm512_and_ps(magic_negative_zero, x);

    // negative_mask_y = (magic_negative_zero & y);  // bitwise: sign bit of y
    __m512 negative_mask_y = _mm512_and_ps(magic_negative_zero, y);

    // pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f);
    // OR-ing the sign bit into PI yields -PI.
    __m512 pi_additions = _mm512_mask_mov_ps(
                              magic_zero,
                              _mm512_cmp_ps_mask(x, magic_zero, _CMP_LT_OQ),
                              _mm512_mask_mov_ps(
                                  magic_pi,
                                  _mm512_cmp_ps_mask(y, magic_zero, _CMP_LT_OQ),
                                  _mm512_or_ps(magic_negative_zero, magic_pi)));

    // normal_result = (atan(y / x) + pi_additions);
    __m512 normal_result = _mm512_add_ps(
                               atan512_ps(_mm512_div_ps(y, x)),
                               pi_additions);

    // negative_mask_full_x = ((negative_mask_x | PI) < 0.0f);
    // Copying the sign of x onto PI makes the compare true exactly when the
    // sign bit of x is set, which also catches x == -0.0f.
    __mmask16 negative_mask_full_x = _mm512_cmp_ps_mask(
                                         _mm512_or_ps(negative_mask_x, magic_pi),
                                         magic_zero,
                                         _CMP_LT_OQ);

    // x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI));
    // x2 = (negative_mask_full_x ? PI : 0.0f);
    // special_result = ((y != 0.0f) ? x1 : x2);
    __m512 special_result = _mm512_mask_mov_ps(
                                _mm512_mask_mov_ps(magic_zero, negative_mask_full_x, magic_pi),
                                not_equal_zero_y,
                                _mm512_or_ps(negative_mask_y, magic_half_pi));

    // return (normal_mode ? normal_result : special_result);
    return _mm512_mask_mov_ps(special_result, normal_mode, normal_result);
}

#endif // AVX512_MATHFUN_H

Loading…
Cancel
Save