| @@ -31,64 +31,8 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| int h = bottom_top_blob.h; | |||
| int d = bottom_top_blob.d; | |||
| int channels = bottom_top_blob.c; | |||
| int size = w * h * d; | |||
| #if __SSE2__ | |||
| int elempack = bottom_top_blob.elempack; | |||
| #if __AVX__ | |||
| #if __AVX512F__ | |||
| if (elempack == 16) | |||
| { | |||
| Mat tmp; | |||
| convert_packing(bottom_top_blob, tmp, 8, opt); | |||
| forward_inplace(tmp, opt); | |||
| convert_packing(tmp, bottom_top_blob, 16, opt); | |||
| return 0; | |||
| } | |||
| #endif // __AVX512F__ | |||
| if (elempack == 8) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < channels; q++) | |||
| { | |||
| float* ptr = bottom_top_blob.channel(q); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| __m256 _p = _mm256_loadu_ps(ptr); | |||
| _p = mish_avx(_p); | |||
| _mm256_storeu_ps(ptr, _p); | |||
| ptr += 8; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| #endif // __AVX__ | |||
| if (elempack == 4) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < channels; q++) | |||
| { | |||
| float* ptr = bottom_top_blob.channel(q); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| __m128 _p = _mm_loadu_ps(ptr); | |||
| _p = mish_sse(_p); | |||
| _mm_storeu_ps(ptr, _p); | |||
| ptr += 4; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| #endif // __SSE2__ | |||
| int size = w * h * d * elempack; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < channels; q++) | |||
| @@ -98,6 +42,15 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| int i = 0; | |||
| #if __SSE2__ | |||
| #if __AVX__ | |||
| #if __AVX512F__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| __m512 _p = _mm512_loadu_ps(ptr); | |||
| _p = mish_avx512(_p); | |||
| _mm512_storeu_ps(ptr, _p); | |||
| ptr += 16; | |||
| } | |||
| #endif | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m256 _p = _mm256_loadu_ps(ptr); | |||