Browse Source

Add SSE and AVX implementation of atan2 in x86 targets. (#4633)

tags/20230517
Kenji Mouri GitHub 3 years ago
parent
commit
d802acd205
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 38 deletions
  1. +12
    -12
      src/layer/x86/avx512_mathfun.h
  2. +63
    -12
      src/layer/x86/avx_mathfun.h
  3. +62
    -14
      src/layer/x86/sse_mathfun.h

+ 12
- 12
src/layer/x86/avx512_mathfun.h View File

@@ -505,18 +505,6 @@ static NCNN_FORCEINLINE __m512 pow512_ps(__m512 a, __m512 b)
return exp512_ps(_mm512_mul_ps(b, log512_ps(a)));
}

static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b)
{
//TODO avx512 optimize
float tmpx[16];
float tmpy[16];
_mm512_storeu_ps(tmpx, a);
_mm512_storeu_ps(tmpy, b);
for (int i = 0; i < 16; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return _mm512_loadu_ps(tmpx);
}

static NCNN_FORCEINLINE __m512 asin512_ps(__m512 x)
{
const __m512 magic_negative_zero = _mm512_set1_ps(-0.0f);
@@ -791,4 +779,16 @@ static NCNN_FORCEINLINE __m512 atan512_ps(__m512 x)
negative_mask);
}

static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b)
{
//TODO avx512 optimize
float tmpx[16];
float tmpy[16];
_mm512_storeu_ps(tmpx, a);
_mm512_storeu_ps(tmpy, b);
for (int i = 0; i < 16; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return _mm512_loadu_ps(tmpx);
}

#endif // AVX512_MATHFUN_H

+ 63
- 12
src/layer/x86/avx_mathfun.h View File

@@ -751,18 +751,6 @@ static NCNN_FORCEINLINE __m256 pow256_ps(__m256 a, __m256 b)
return exp256_ps(_mm256_mul_ps(b, log256_ps(a)));
}

static NCNN_FORCEINLINE __m256 atan2256_ps(__m256 a, __m256 b)
{
//TODO avx optimize
float tmpx[8];
float tmpy[8];
_mm256_storeu_ps(tmpx, a);
_mm256_storeu_ps(tmpy, b);
for (int i = 0; i < 8; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return _mm256_loadu_ps(tmpx);
}

static NCNN_FORCEINLINE __m256 asin256_ps(__m256 x)
{
const __m256 magic_negative_zero = _mm256_set1_ps(-0.0f);
@@ -1027,4 +1015,67 @@ static NCNN_FORCEINLINE __m256 atan256_ps(__m256 x)
negative_mask);
}

static NCNN_FORCEINLINE __m256 atan2256_ps(__m256 y, __m256 x)
{
// Reference: https://mazzo.li/posts/vectorized-atan2.html

const __m256 magic_zero = _mm256_set1_ps(0.0f);
const __m256 magic_negative_zero = _mm256_set1_ps(-0.0f);
const __m256 magic_pi = _mm256_set1_ps(3.1415927f);
const __m256 magic_half_pi = _mm256_set1_ps(1.5707964f);

// not_equal_zero_x = (x != 0.0f);
__m256 not_equal_zero_x = _mm256_cmp_ps(x, magic_zero, _CMP_NEQ_OQ);

// not_equal_zero_y = (y != 0.0f);
__m256 not_equal_zero_y = _mm256_cmp_ps(y, magic_zero, _CMP_NEQ_OQ);

// normal_mode = ((x != 0.0f) & (y != 0.0f));
__m256 normal_mode = _mm256_and_ps(not_equal_zero_x, not_equal_zero_y);

// negative_mask_x = magic_negative_zero && x;
__m256 negative_mask_x = _mm256_and_ps(magic_negative_zero, x);

// negative_mask_y = magic_negative_zero && y;
__m256 negative_mask_y = _mm256_and_ps(magic_negative_zero, y);

// pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f);
__m256 pi_additions = _mm256_and_ps(
_mm256_cmp_ps(x, magic_zero, _CMP_LT_OQ),
_mm256_or_ps(
_mm256_and_ps(
_mm256_cmp_ps(y, magic_zero, _CMP_LT_OQ),
magic_negative_zero),
magic_pi));

// normal_result = (atan(y / x) + pi_additions);
__m256 normal_result = _mm256_add_ps(
atan256_ps(_mm256_div_ps(y, x)),
pi_additions);

// negative_mask_full_x = ((negative_mask_x | PI) < 0.0f);
__m256 negative_mask_full_x = _mm256_cmp_ps(
_mm256_or_ps(negative_mask_x, magic_pi),
magic_zero,
_CMP_LT_OQ);

// x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI));
// x2 = (negative_mask_full_x ? PI : 0.0f);
// special_result = ((y != 0.0f) ? x1 : x2);
__m256 special_result = _mm256_or_ps(
_mm256_and_ps(
not_equal_zero_y,
_mm256_or_ps(negative_mask_y, magic_half_pi)),
_mm256_andnot_ps(
not_equal_zero_y,
_mm256_or_ps(
_mm256_and_ps(negative_mask_full_x, magic_pi),
_mm256_andnot_ps(negative_mask_full_x, magic_zero))));

// return (normal_mode ? normal_result : special_result);
return _mm256_or_ps(
_mm256_and_ps(normal_mode, normal_result),
_mm256_andnot_ps(normal_mode, special_result));
}

#endif // AVX_MATHFUN_H

+ 62
- 14
src/layer/x86/sse_mathfun.h View File

@@ -738,20 +738,6 @@ static NCNN_FORCEINLINE __m128 pow_ps(__m128 a, __m128 b)
return exp_ps(_mm_mul_ps(b, log_ps(a)));
}

static NCNN_FORCEINLINE __m128 atan2_ps(__m128 a, __m128 b)
{
//TODO sse optimize
float tmpx[4];
float tmpy[4];
_mm_storeu_ps(tmpx, a);
_mm_storeu_ps(tmpy, b);
tmpx[0] = atan2(tmpx[0], tmpy[0]);
tmpx[1] = atan2(tmpx[1], tmpy[1]);
tmpx[2] = atan2(tmpx[2], tmpy[2]);
tmpx[3] = atan2(tmpx[3], tmpy[3]);
return _mm_loadu_ps(tmpx);
}

static NCNN_FORCEINLINE __m128 ceil_ps(__m128 x)
{
#if __SSE4_1__
@@ -1100,4 +1086,66 @@ static NCNN_FORCEINLINE __m128 atan_ps(__m128 x)
negative_mask);
}

static NCNN_FORCEINLINE __m128 atan2_ps(__m128 y, __m128 x)
{
// Reference: https://mazzo.li/posts/vectorized-atan2.html

const __m128 magic_zero = _mm_set_ps1(0.0f);
const __m128 magic_negative_zero = _mm_set_ps1(-0.0f);
const __m128 magic_pi = _mm_set_ps1(3.1415927f);
const __m128 magic_half_pi = _mm_set_ps1(1.5707964f);

// not_equal_zero_x = (x != 0.0f);
__m128 not_equal_zero_x = _mm_cmpneq_ps(x, magic_zero);

// not_equal_zero_y = (y != 0.0f);
__m128 not_equal_zero_y = _mm_cmpneq_ps(y, magic_zero);

// normal_mode = ((x != 0.0f) & (y != 0.0f));
__m128 normal_mode = _mm_and_ps(not_equal_zero_x, not_equal_zero_y);

// negative_mask_x = magic_negative_zero && x;
__m128 negative_mask_x = _mm_and_ps(magic_negative_zero, x);

// negative_mask_y = magic_negative_zero && y;
__m128 negative_mask_y = _mm_and_ps(magic_negative_zero, y);

// pi_additions = ((x < 0.0f) ? ((y < 0.0f) ? -PI : PI) : 0.0f);
__m128 pi_additions = _mm_and_ps(
_mm_cmplt_ps(x, magic_zero),
_mm_or_ps(
_mm_and_ps(
_mm_cmplt_ps(y, magic_zero),
magic_negative_zero),
magic_pi));

// normal_result = (atan(y / x) + pi_additions);
__m128 normal_result = _mm_add_ps(
atan_ps(_mm_div_ps(y, x)),
pi_additions);

// negative_mask_full_x = ((negative_mask_x | PI) < 0.0f);
__m128 negative_mask_full_x = _mm_cmplt_ps(
_mm_or_ps(negative_mask_x, magic_pi),
magic_zero);

// x1 = (negative_mask_y ? -(0.5 * PI) : (0.5 * PI));
// x2 = (negative_mask_full_x ? PI : 0.0f);
// special_result = ((y != 0.0f) ? x1 : x2);
__m128 special_result = _mm_or_ps(
_mm_and_ps(
not_equal_zero_y,
_mm_or_ps(negative_mask_y, magic_half_pi)),
_mm_andnot_ps(
not_equal_zero_y,
_mm_or_ps(
_mm_and_ps(negative_mask_full_x, magic_pi),
_mm_andnot_ps(negative_mask_full_x, magic_zero))));

// return (normal_mode ? normal_result : special_result);
return _mm_or_ps(
_mm_and_ps(normal_mode, normal_result),
_mm_andnot_ps(normal_mode, special_result));
}

#endif // SSE_MATHFUN_H

Loading…
Cancel
Save