| @@ -779,16 +779,72 @@ static NCNN_FORCEINLINE __m512 atan512_ps(__m512 x) | |||
| negative_mask); | |||
| } | |||
| static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b) | |||
| // MSVC 2017 x86 CI will be broken if use NCNN_FORCEINLINE for atan2512_ps. | |||
| // This function still be inlined compiled by MSVC 2017 even without that. | |||
| #if _MSC_VER < 1920 | |||
| static __m512 atan2512_ps(__m512 y, __m512 x) | |||
| #else | |||
| static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 y, __m512 x) | |||
| #endif | |||
| { | |||
| //TODO avx512 optimize | |||
| float tmpx[16]; | |||
| float tmpy[16]; | |||
| _mm512_storeu_ps(tmpx, a); | |||
| _mm512_storeu_ps(tmpy, b); | |||
| for (int i = 0; i < 16; i++) | |||
| tmpx[i] = atan2(tmpx[i], tmpy[i]); | |||
| return _mm512_loadu_ps(tmpx); | |||
| // Reference: https://mazzo.li/posts/vectorized-atan2.html | |||
| const __m512 magic_zero = _mm512_set1_ps(0.0f); | |||
| const __m512 magic_negative_zero = _mm512_set1_ps(-0.0f); | |||
| const __m512 magic_pi = _mm512_set1_ps(3.1415927f); | |||
| const __m512 magic_half_pi = _mm512_set1_ps(1.5707964f); | |||
| // not_equal_zero_x = (x != 0.0f); | |||
| __mmask16 not_equal_zero_x = _mm512_cmp_ps_mask( | |||
| x, | |||
| magic_zero, | |||
| _CMP_NEQ_OQ); | |||
| // not_equal_zero_y = (y != 0.0f); | |||
| __mmask16 not_equal_zero_y = _mm512_cmp_ps_mask( | |||
| y, | |||
| magic_zero, | |||
| _CMP_NEQ_OQ); | |||
| // normal_mode = ((x != 0.0f) & (y != 0.0f)); | |||
| __mmask16 normal_mode = (not_equal_zero_x & not_equal_zero_y); | |||
| // negative_mask_x = magic_negative_zero && x; | |||
| __m512 negative_mask_x = _mm512_and_ps(magic_negative_zero, x); | |||
| // negative_mask_y = magic_negative_zero && y; | |||
| __m512 negative_mask_y = _mm512_and_ps(magic_negative_zero, y); | |||
| // pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f); | |||
| __m512 pi_additions = _mm512_mask_mov_ps( | |||
| magic_zero, | |||
| _mm512_cmp_ps_mask(x, magic_zero, _CMP_LT_OQ), | |||
| _mm512_mask_mov_ps( | |||
| magic_pi, | |||
| _mm512_cmp_ps_mask(y, magic_zero, _CMP_LT_OQ), | |||
| _mm512_or_ps(magic_negative_zero, magic_pi))); | |||
| // normal_result = (atan(y / x) + pi_additions); | |||
| __m512 normal_result = _mm512_add_ps( | |||
| atan512_ps(_mm512_div_ps(y, x)), | |||
| pi_additions); | |||
| // negative_mask_full_x = ((negative_mask_x | PI) < 0.0f); | |||
| __mmask16 negative_mask_full_x = _mm512_cmp_ps_mask( | |||
| _mm512_or_ps(negative_mask_x, magic_pi), | |||
| magic_zero, | |||
| _CMP_LT_OQ); | |||
| // x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI)); | |||
| // x2 = (negative_mask_full_x ? PI : 0.0f); | |||
| // special_result = ((y != 0.0f) ? x1 : x2); | |||
| __m512 special_result = _mm512_mask_mov_ps( | |||
| _mm512_mask_mov_ps(magic_zero, negative_mask_full_x, magic_pi), | |||
| not_equal_zero_y, | |||
| _mm512_or_ps(negative_mask_y, magic_half_pi)); | |||
| // return (normal_mode ? normal_result : special_result); | |||
| return _mm512_mask_mov_ps(special_result, normal_mode, normal_result); | |||
| } | |||
| #endif // AVX512_MATHFUN_H | |||