From f2a5a81a5d586cbac7e37490bcbf75809447a49a Mon Sep 17 00:00:00 2001 From: Kenji Mouri Date: Wed, 19 Apr 2023 22:53:09 +0800 Subject: [PATCH] Add AVX512F implementation of atan2 in x86 targets. (#4641) --- src/layer/x86/avx512_mathfun.h | 74 +++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h index be06ab943..e1cc70b70 100644 --- a/src/layer/x86/avx512_mathfun.h +++ b/src/layer/x86/avx512_mathfun.h @@ -779,16 +779,72 @@ static NCNN_FORCEINLINE __m512 atan512_ps(__m512 x) negative_mask); } -static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b) +// MSVC 2017 x86 CI will be broken if use NCNN_FORCEINLINE for atan2512_ps. +// This function still be inlined compiled by MSVC 2017 even without that. +#if _MSC_VER < 1920 +static __m512 atan2512_ps(__m512 y, __m512 x) +#else +static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 y, __m512 x) +#endif { - //TODO avx512 optimize - float tmpx[16]; - float tmpy[16]; - _mm512_storeu_ps(tmpx, a); - _mm512_storeu_ps(tmpy, b); - for (int i = 0; i < 16; i++) - tmpx[i] = atan2(tmpx[i], tmpy[i]); - return _mm512_loadu_ps(tmpx); + // Reference: https://mazzo.li/posts/vectorized-atan2.html + + const __m512 magic_zero = _mm512_set1_ps(0.0f); + const __m512 magic_negative_zero = _mm512_set1_ps(-0.0f); + const __m512 magic_pi = _mm512_set1_ps(3.1415927f); + const __m512 magic_half_pi = _mm512_set1_ps(1.5707964f); + + // not_equal_zero_x = (x != 0.0f); + __mmask16 not_equal_zero_x = _mm512_cmp_ps_mask( + x, + magic_zero, + _CMP_NEQ_OQ); + + // not_equal_zero_y = (y != 0.0f); + __mmask16 not_equal_zero_y = _mm512_cmp_ps_mask( + y, + magic_zero, + _CMP_NEQ_OQ); + + // normal_mode = ((x != 0.0f) & (y != 0.0f)); + __mmask16 normal_mode = (not_equal_zero_x & not_equal_zero_y); + + // negative_mask_x = magic_negative_zero && x; + __m512 negative_mask_x = _mm512_and_ps(magic_negative_zero, x); + + // negative_mask_y = magic_negative_zero && y; + __m512 negative_mask_y = _mm512_and_ps(magic_negative_zero, y); + + // pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f); + __m512 pi_additions = _mm512_mask_mov_ps( + magic_zero, + _mm512_cmp_ps_mask(x, magic_zero, _CMP_LT_OQ), + _mm512_mask_mov_ps( + magic_pi, + _mm512_cmp_ps_mask(y, magic_zero, _CMP_LT_OQ), + _mm512_or_ps(magic_negative_zero, magic_pi))); + + // normal_result = (atan(y / x) + pi_additions); + __m512 normal_result = _mm512_add_ps( + atan512_ps(_mm512_div_ps(y, x)), + pi_additions); + + // negative_mask_full_x = ((negative_mask_x | PI) < 0.0f); + __mmask16 negative_mask_full_x = _mm512_cmp_ps_mask( + _mm512_or_ps(negative_mask_x, magic_pi), + magic_zero, + _CMP_LT_OQ); + + // x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI)); + // x2 = (negative_mask_full_x ? PI : 0.0f); + // special_result = ((y != 0.0f) ? x1 : x2); + __m512 special_result = _mm512_mask_mov_ps( + _mm512_mask_mov_ps(magic_zero, negative_mask_full_x, magic_pi), + not_equal_zero_y, + _mm512_or_ps(negative_mask_y, magic_half_pi)); + + // return (normal_mode ? normal_result : special_result); + return _mm512_mask_mov_ps(special_result, normal_mode, normal_result); } #endif // AVX512_MATHFUN_H