diff --git a/src/layer/x86/mish_x86.cpp b/src/layer/x86/mish_x86.cpp index e55a5e1f8..90ce135c1 100644 --- a/src/layer/x86/mish_x86.cpp +++ b/src/layer/x86/mish_x86.cpp @@ -31,64 +31,8 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int h = bottom_top_blob.h; int d = bottom_top_blob.d; int channels = bottom_top_blob.c; - int size = w * h * d; -#if __SSE2__ int elempack = bottom_top_blob.elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - Mat tmp; - convert_packing(bottom_top_blob, tmp, 8, opt); - - forward_inplace(tmp, opt); - - convert_packing(tmp, bottom_top_blob, 16, opt); - - return 0; - } -#endif // __AVX512F__ - - if (elempack == 8) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - float* ptr = bottom_top_blob.channel(q); - - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = mish_avx(_p); - _mm256_storeu_ps(ptr, _p); - ptr += 8; - } - } - - return 0; - } -#endif // __AVX__ - - if (elempack == 4) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - float* ptr = bottom_top_blob.channel(q); - - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = mish_sse(_p); - _mm_storeu_ps(ptr, _p); - ptr += 4; - } - } - - return 0; - } -#endif // __SSE2__ + int size = w * h * d * elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -98,6 +42,15 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int i = 0; #if __SSE2__ #if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _p = mish_avx512(_p); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + } +#endif for (; i + 7 < size; i += 8) { __m256 _p = _mm256_loadu_ps(ptr);