From d802acd205f8fbaba4edfbe1fd4fca0595f97784 Mon Sep 17 00:00:00 2001 From: Kenji Mouri Date: Tue, 18 Apr 2023 14:14:08 +0800 Subject: [PATCH] Add SSE and AVX implementation of atan2 in x86 targets. (#4633) --- src/layer/x86/avx512_mathfun.h | 24 +++++------ src/layer/x86/avx_mathfun.h | 75 +++++++++++++++++++++++++++------ src/layer/x86/sse_mathfun.h | 76 +++++++++++++++++++++++++++------- 3 files changed, 137 insertions(+), 38 deletions(-) diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h index 5b132f8bf..be06ab943 100644 --- a/src/layer/x86/avx512_mathfun.h +++ b/src/layer/x86/avx512_mathfun.h @@ -505,18 +505,6 @@ static NCNN_FORCEINLINE __m512 pow512_ps(__m512 a, __m512 b) return exp512_ps(_mm512_mul_ps(b, log512_ps(a))); } -static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b) -{ - //TODO avx512 optimize - float tmpx[16]; - float tmpy[16]; - _mm512_storeu_ps(tmpx, a); - _mm512_storeu_ps(tmpy, b); - for (int i = 0; i < 16; i++) - tmpx[i] = atan2(tmpx[i], tmpy[i]); - return _mm512_loadu_ps(tmpx); -} - static NCNN_FORCEINLINE __m512 asin512_ps(__m512 x) { const __m512 magic_negative_zero = _mm512_set1_ps(-0.0f); @@ -791,4 +779,16 @@ static NCNN_FORCEINLINE __m512 atan512_ps(__m512 x) negative_mask); } +static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b) +{ + //TODO avx512 optimize + float tmpx[16]; + float tmpy[16]; + _mm512_storeu_ps(tmpx, a); + _mm512_storeu_ps(tmpy, b); + for (int i = 0; i < 16; i++) + tmpx[i] = atan2(tmpx[i], tmpy[i]); + return _mm512_loadu_ps(tmpx); +} + #endif // AVX512_MATHFUN_H diff --git a/src/layer/x86/avx_mathfun.h b/src/layer/x86/avx_mathfun.h index 7a1f2a95e..19cb40f6d 100644 --- a/src/layer/x86/avx_mathfun.h +++ b/src/layer/x86/avx_mathfun.h @@ -751,18 +751,6 @@ static NCNN_FORCEINLINE __m256 pow256_ps(__m256 a, __m256 b) return exp256_ps(_mm256_mul_ps(b, log256_ps(a))); } -static NCNN_FORCEINLINE __m256 atan2256_ps(__m256 a, __m256 b) -{ - //TODO avx optimize - float tmpx[8]; - float tmpy[8]; - _mm256_storeu_ps(tmpx, a); - _mm256_storeu_ps(tmpy, b); - for (int i = 0; i < 8; i++) - tmpx[i] = atan2(tmpx[i], tmpy[i]); - return _mm256_loadu_ps(tmpx); -} - static NCNN_FORCEINLINE __m256 asin256_ps(__m256 x) { const __m256 magic_negative_zero = _mm256_set1_ps(-0.0f); @@ -1027,4 +1015,67 @@ static NCNN_FORCEINLINE __m256 atan256_ps(__m256 x) negative_mask); } +static NCNN_FORCEINLINE __m256 atan2256_ps(__m256 y, __m256 x) +{ + // Reference: https://mazzo.li/posts/vectorized-atan2.html + + const __m256 magic_zero = _mm256_set1_ps(0.0f); + const __m256 magic_negative_zero = _mm256_set1_ps(-0.0f); + const __m256 magic_pi = _mm256_set1_ps(3.1415927f); + const __m256 magic_half_pi = _mm256_set1_ps(1.5707964f); + + // not_equal_zero_x = (x != 0.0f); + __m256 not_equal_zero_x = _mm256_cmp_ps(x, magic_zero, _CMP_NEQ_OQ); + + // not_equal_zero_y = (y != 0.0f); + __m256 not_equal_zero_y = _mm256_cmp_ps(y, magic_zero, _CMP_NEQ_OQ); + + // normal_mode = ((x != 0.0f) & (y != 0.0f)); + __m256 normal_mode = _mm256_and_ps(not_equal_zero_x, not_equal_zero_y); + + // negative_mask_x = magic_negative_zero && x; + __m256 negative_mask_x = _mm256_and_ps(magic_negative_zero, x); + + // negative_mask_y = magic_negative_zero && y; + __m256 negative_mask_y = _mm256_and_ps(magic_negative_zero, y); + + // pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f); + __m256 pi_additions = _mm256_and_ps( + _mm256_cmp_ps(x, magic_zero, _CMP_LT_OQ), + _mm256_or_ps( + _mm256_and_ps( + _mm256_cmp_ps(y, magic_zero, _CMP_LT_OQ), + magic_negative_zero), + magic_pi)); + + // normal_result = (atan(y / x) + pi_additions); + __m256 normal_result = _mm256_add_ps( + atan256_ps(_mm256_div_ps(y, x)), + pi_additions); + + // negative_mask_full_x = ((negative_mask_x | PI) < 0.0f); + __m256 negative_mask_full_x = _mm256_cmp_ps( + _mm256_or_ps(negative_mask_x, magic_pi), + magic_zero, + _CMP_LT_OQ); + + // x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI)); + // x2 = (negative_mask_full_x ? PI : 0.0f); + // special_result = ((y != 0.0f) ? x1 : x2); + __m256 special_result = _mm256_or_ps( + _mm256_and_ps( + not_equal_zero_y, + _mm256_or_ps(negative_mask_y, magic_half_pi)), + _mm256_andnot_ps( + not_equal_zero_y, + _mm256_or_ps( + _mm256_and_ps(negative_mask_full_x, magic_pi), + _mm256_andnot_ps(negative_mask_full_x, magic_zero)))); + + // return (normal_mode ? normal_result : special_result); + return _mm256_or_ps( + _mm256_and_ps(normal_mode, normal_result), + _mm256_andnot_ps(normal_mode, special_result)); +} + #endif // AVX_MATHFUN_H diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h index 3de87691b..fb4b22192 100644 --- a/src/layer/x86/sse_mathfun.h +++ b/src/layer/x86/sse_mathfun.h @@ -738,20 +738,6 @@ static NCNN_FORCEINLINE __m128 pow_ps(__m128 a, __m128 b) return exp_ps(_mm_mul_ps(b, log_ps(a))); } -static NCNN_FORCEINLINE __m128 atan2_ps(__m128 a, __m128 b) -{ - //TODO sse optimize - float tmpx[4]; - float tmpy[4]; - _mm_storeu_ps(tmpx, a); - _mm_storeu_ps(tmpy, b); - tmpx[0] = atan2(tmpx[0], tmpy[0]); - tmpx[1] = atan2(tmpx[1], tmpy[1]); - tmpx[2] = atan2(tmpx[2], tmpy[2]); - tmpx[3] = atan2(tmpx[3], tmpy[3]); - return _mm_loadu_ps(tmpx); -} - static NCNN_FORCEINLINE __m128 ceil_ps(__m128 x) { #if __SSE4_1__ @@ -1100,4 +1086,66 @@ static NCNN_FORCEINLINE __m128 atan_ps(__m128 x) negative_mask); } +static NCNN_FORCEINLINE __m128 atan2_ps(__m128 y, __m128 x) +{ + // Reference: https://mazzo.li/posts/vectorized-atan2.html + + const __m128 magic_zero = _mm_set_ps1(0.0f); + const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); + const __m128 magic_pi = _mm_set_ps1(3.1415927f); + const __m128 magic_half_pi = _mm_set_ps1(1.5707964f); + + // not_equal_zero_x = (x != 0.0f); + __m128 not_equal_zero_x = _mm_cmpneq_ps(x, magic_zero); + + // not_equal_zero_y = (y != 0.0f); + __m128 not_equal_zero_y = _mm_cmpneq_ps(y, magic_zero); + + // normal_mode = ((x != 0.0f) & (y != 0.0f)); + __m128 normal_mode = _mm_and_ps(not_equal_zero_x, not_equal_zero_y); + + // negative_mask_x = magic_negative_zero && x; + __m128 negative_mask_x = _mm_and_ps(magic_negative_zero, x); + + // negative_mask_y = magic_negative_zero && y; + __m128 negative_mask_y = _mm_and_ps(magic_negative_zero, y); + + // pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f); + __m128 pi_additions = _mm_and_ps( + _mm_cmplt_ps(x, magic_zero), + _mm_or_ps( + _mm_and_ps( + _mm_cmplt_ps(y, magic_zero), + magic_negative_zero), + magic_pi)); + + // normal_result = (atan(y / x) + pi_additions); + __m128 normal_result = _mm_add_ps( + atan_ps(_mm_div_ps(y, x)), + pi_additions); + + // negative_mask_full_x = ((negative_mask_x | PI) < 0.0f); + __m128 negative_mask_full_x = _mm_cmplt_ps( + _mm_or_ps(negative_mask_x, magic_pi), + magic_zero); + + // x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI)); + // x2 = (negative_mask_full_x ? PI : 0.0f); + // special_result = ((y != 0.0f) ? x1 : x2); + __m128 special_result = _mm_or_ps( + _mm_and_ps( + not_equal_zero_y, + _mm_or_ps(negative_mask_y, magic_half_pi)), + _mm_andnot_ps( + not_equal_zero_y, + _mm_or_ps( + _mm_and_ps(negative_mask_full_x, magic_pi), + _mm_andnot_ps(negative_mask_full_x, magic_zero)))); + + // return (normal_mode ? normal_result : special_result); + return _mm_or_ps( + _mm_and_ps(normal_mode, normal_result), + _mm_andnot_ps(normal_mode, special_result)); +} + #endif // SSE_MATHFUN_H