diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h index 54d091dd4..4df439323 100644 --- a/src/layer/x86/sse_mathfun.h +++ b/src/layer/x86/sse_mathfun.h @@ -752,4 +752,88 @@ static NCNN_FORCEINLINE __m128 atan2_ps(__m128 a, __m128 b) return _mm_loadu_ps(tmpx); } +static NCNN_FORCEINLINE __m128 ceil_ps(__m128 x) +{ +#if __SSE4_1__ + return _mm_ceil_ps(x); +#endif // __SSE4_1__ + + // Use negative zero as the sign bit mask. + const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); + + // The smallest float number that have no fractional part. (2^23) + const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f); + + // absolute = abs(x); + __m128 absolute = _mm_andnot_ps(magic_negative_zero, x); + + // negative_mask = magic_negative_zero && x; + __m128 negative_mask = _mm_and_ps(magic_negative_zero, x); + + // no_fraction = (magic_smallest_no_fraction < absolute); + __m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute); + + // truncated = static_cast(static_cast(absolute)); + __m128 truncated = _mm_cvtepi32_ps(_mm_cvttps_epi32(absolute)); + + // truncated_with_sign = (truncated || negative_mask); + __m128 truncated_with_sign = _mm_or_ps(truncated, negative_mask); + + // positive_fix = ((x > -0.0f) && (x > truncated_with_sign) ? -1.0f : 0.0f); + __m128 positive_fix = _mm_and_ps( + _mm_and_ps( + _mm_cmpgt_ps(x, magic_negative_zero), + _mm_cmpgt_ps(x, truncated_with_sign)), + _mm_set_ps1(-1.0f)); + + // fixed_result = truncated_with_sign - positive_fix; + __m128 fixed_result = _mm_sub_ps(truncated_with_sign, positive_fix); + + // return ((x && no_fraction) || (!no_fraction && fixed_result)); + return _mm_or_ps( + _mm_and_ps(x, no_fraction), + _mm_andnot_ps(no_fraction, fixed_result)); +} + +static NCNN_FORCEINLINE __m128 floor_ps(__m128 x) +{ +#if __SSE4_1__ + return _mm_floor_ps(x); +#endif // __SSE4_1__ + + // Use negative zero as the sign bit mask. + const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); + + // The smallest float number that have no fractional part. (2^23) + const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f); + + // absolute = abs(x); + __m128 absolute = _mm_andnot_ps(magic_negative_zero, x); + + // negative_mask = magic_negative_zero && x; + __m128 negative_mask = _mm_and_ps(magic_negative_zero, x); + + // no_fraction = (magic_smallest_no_fraction < absolute); + __m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute); + + // truncated = static_cast(static_cast(absolute)); + __m128 truncated = _mm_cvtepi32_ps(_mm_cvttps_epi32(absolute)); + + // truncated_with_sign = (truncated || negative_mask); + __m128 truncated_with_sign = _mm_or_ps(truncated, negative_mask); + + // negative_fix = ((x < truncated_with_sign) ? 1.0f : 0.0f); + __m128 negative_fix = _mm_and_ps( + _mm_cmplt_ps(x, truncated_with_sign), + _mm_set_ps1(1.0f)); + + // fixed_result = truncated_with_sign - negative_fix; + __m128 fixed_result = _mm_sub_ps(truncated_with_sign, negative_fix); + + // return ((x && no_fraction) || (!no_fraction && fixed_result)); + return _mm_or_ps( + _mm_and_ps(x, no_fraction), + _mm_andnot_ps(no_fraction, fixed_result)); +} + #endif // SSE_MATHFUN_H diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index faba25dbd..987eaffcf 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -158,43 +158,7 @@ struct unary_op_floor #if __SSE2__ __m128 func_pack4(const __m128& x) const { -#if __SSE4_1__ - return _mm_floor_ps(x); -#endif // __SSE4_1__ - - // Use negative zero as the sign bit mask. - const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); - - // The smallest float number that have no fractional part. (2^23) - const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f); - - // absolute = abs(x); - __m128 absolute = _mm_andnot_ps(magic_negative_zero, x); - - // negative_mask = magic_negative_zero && x; - __m128 negative_mask = _mm_and_ps(magic_negative_zero, x); - - // no_fraction = (magic_smallest_no_fraction < absolute); - __m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute); - - // truncated = static_cast(static_cast(absolute)); - __m128 truncated = _mm_cvtepi32_ps(_mm_cvttps_epi32(absolute)); - - // truncated_with_sign = (truncated || negative_mask); - __m128 truncated_with_sign = _mm_or_ps(truncated, negative_mask); - - // negative_fix = ((x < truncated_with_sign) ? 1.0f : 0.0f); - __m128 negative_fix = _mm_and_ps( - _mm_cmplt_ps(x, truncated_with_sign), - _mm_set_ps1(1.0f)); - - // fixed_result = truncated_with_sign - negative_fix; - __m128 fixed_result = _mm_sub_ps(truncated_with_sign, negative_fix); - - // return ((x && no_fraction) || (!no_fraction && fixed_result)); - return _mm_or_ps( - _mm_and_ps(x, no_fraction), - _mm_andnot_ps(no_fraction, fixed_result)); + return floor_ps(x); } #if __AVX__ __m256 func_pack8(const __m256& x) const @@ -220,45 +184,7 @@ struct unary_op_ceil #if __SSE2__ __m128 func_pack4(const __m128& x) const { -#if __SSE4_1__ - return _mm_ceil_ps(x); -#endif // __SSE4_1__ - - // Use negative zero as the sign bit mask. - const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); - - // The smallest float number that have no fractional part. (2^23) - const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f); - - // absolute = abs(x); - __m128 absolute = _mm_andnot_ps(magic_negative_zero, x); - - // negative_mask = magic_negative_zero && x; - __m128 negative_mask = _mm_and_ps(magic_negative_zero, x); - - // no_fraction = (magic_smallest_no_fraction < absolute); - __m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute); - - // truncated = static_cast(static_cast(absolute)); - __m128 truncated = _mm_cvtepi32_ps(_mm_cvttps_epi32(absolute)); - - // truncated_with_sign = (truncated || negative_mask); - __m128 truncated_with_sign = _mm_or_ps(truncated, negative_mask); - - // positive_fix = ((x > -0.0f) && (x > truncated_with_sign) ? -1.0f : 0.0f); - __m128 positive_fix = _mm_and_ps( - _mm_and_ps( - _mm_cmpgt_ps(x, magic_negative_zero), - _mm_cmpgt_ps(x, truncated_with_sign)), - _mm_set_ps1(-1.0f)); - - // fixed_result = truncated_with_sign - positive_fix; - __m128 fixed_result = _mm_sub_ps(truncated_with_sign, positive_fix); - - // return ((x && no_fraction) || (!no_fraction && fixed_result)); - return _mm_or_ps( - _mm_and_ps(x, no_fraction), - _mm_andnot_ps(no_fraction, fixed_result)); + return ceil_ps(x); } #if __AVX__ __m256 func_pack8(const __m256& x) const