diff --git a/src/layer/x86/scale_x86.cpp b/src/layer/x86/scale_x86.cpp index 657c36ffd..868b1c53c 100644 --- a/src/layer/x86/scale_x86.cpp +++ b/src/layer/x86/scale_x86.cpp @@ -37,6 +37,7 @@ int Scale_x86::forward_inplace(std::vector& bottom_top_blobs, const Option& const int w = bottom_top_blob.w; const int h = bottom_top_blob.h; + const int d = bottom_top_blob.d; const int channels = bottom_top_blob.c; const int dims = bottom_top_blob.dims; @@ -48,427 +49,300 @@ int Scale_x86::forward_inplace(std::vector& bottom_top_blobs, const Option& if (dims == 1) { float* ptr = (float*)bottom_top_blob; - int size = w * elempack; + const int size = w * elempack; - int remain = size; + if (bias_term) + { + int nn_size = 0; + int remain_size_start = 0; #if __SSE2__ #if __AVX__ - int nn = size >> 3; - remain = size & 7; - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < nn; i++) - { - __m256 _p = _mm256_loadu_ps(ptr + i * 8); - __m256 _s = _mm256_loadu_ps(scale + i * 8); - if (bias_term) +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - __m256 _bias = _mm256_loadu_ps(bias + i * 8); - _p = _mm256_comp_fmadd_ps(_p, _s, _bias); + int i = remain_size_start + ii * 16; + __m512 _p512 = _mm512_loadu_ps(ptr + i); + __m512 _s512 = _mm512_loadu_ps(scale + i); + __m512 _bias512 = _mm512_loadu_ps(bias + i); + _mm512_storeu_ps(ptr + i, _mm512_fmadd_ps(_p512, _s512, _bias512)); } - else + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _p256 = _mm256_loadu_ps(ptr + i); + __m256 _s256 = _mm256_loadu_ps(scale + i); + __m256 _bias256 = _mm256_loadu_ps(bias + i); + _mm256_storeu_ps(ptr + i, _mm256_comp_fmadd_ps(_p256, _s256, _bias256)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _p128 = _mm_load_ps(ptr + i); + __m128 _s128 = _mm_load_ps(scale + i); + __m128 _bias128 = _mm_loadu_ps(bias + i); + _mm_store_ps(ptr + i, _mm_comp_fmadd_ps(_p128, _s128, _bias128)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) { - _p = _mm256_mul_ps(_p, _s); + ptr[i] = ptr[i] * scale[i] + bias[i]; } - _mm256_storeu_ps(ptr + i * 8, _p); } -#else - int nn = size >> 2; - remain = size & 3; - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < nn; i++) + else { - __m128 _p = _mm_loadu_ps(ptr + i * 4); - __m128 _s = _mm_loadu_ps(scale + i * 4); - if (bias_term) + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - __m128 _bias = _mm_loadu_ps(bias + i * 4); - _p = _mm_comp_fmadd_ps(_p, _s, _bias); + int i = remain_size_start + ii * 16; + __m512 _p512 = _mm512_loadu_ps(ptr + i); + __m512 _s512 = _mm512_loadu_ps(scale + i); + _mm512_storeu_ps(ptr + i, _mm512_mul_ps(_p512, _s512)); } - else + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - _p = _mm_mul_ps(_p, _s); + int i = remain_size_start + ii * 8; + __m256 _p256 = _mm256_loadu_ps(ptr + i); + __m256 _s256 = _mm256_loadu_ps(scale + i); + _mm256_storeu_ps(ptr + i, _mm256_mul_ps(_p256, _s256)); } - _mm_storeu_ps(ptr + i * 4, _p); - } + remain_size_start += nn_size * 8; #endif // __AVX__ -#endif // __SSE2__ - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = size - remain; i < size; i++) - { - if (bias_term) + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - ptr[i] = ptr[i] * scale[i] + bias[i]; + int i = remain_size_start + ii * 4; + __m128 _p128 = _mm_load_ps(ptr + i); + __m128 _s128 = _mm_load_ps(scale + i); + _mm_store_ps(ptr + i, _mm_mul_ps(_p128, _s128)); } - else + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) { ptr[i] = ptr[i] * scale[i]; } } - - return 0; } -#if __SSE2__ -#if __AVX__ - if (elempack == 8) + if (dims == 2) { - if (dims == 2) - { - if (bias_term) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) - { - float* ptr = bottom_top_blob.row(i); - __m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8); - __m256 _bias = _mm256_loadu_ps((const float*)bias_data + i * 8); - - for (int j = 0; j < w; j++) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_comp_fmadd_ps(_p, _s, _bias); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } - } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) - { - float* ptr = bottom_top_blob.row(i); - __m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8); - - for (int j = 0; j < w; j++) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_mul_ps(_p, _s); - _mm256_storeu_ps(ptr, _p); + const int size = w * elempack; - ptr += 8; - } - } - } - } - - if (dims == 3) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) { - int size = w * h; + float* ptr = bottom_top_blob.row(i); + + float s = scale[i]; +#if __SSE2__ + __m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale + i * 4) : _mm_set1_ps(s); +#if __AVX__ + __m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale + i * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_s128), _s128, 1); +#if __AVX512F__ + __m512 _s512 = (elempack == 16) ? _mm512_loadu_ps(scale + i * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_s256), _s256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ if (bias_term) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - float* ptr = bottom_top_blob.channel(q); - __m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8); - __m256 _bias = _mm256_loadu_ps((const float*)bias_data + q * 8); - - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_comp_fmadd_ps(_p, _s, _bias); - _mm256_storeu_ps(ptr, _p); + float b = bias[i]; +#if __SSE2__ + __m128 _b128 = (elempack == 4) ? _mm_loadu_ps(bias + i * 4) : _mm_set1_ps(b); +#if __AVX__ + __m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(bias + i * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b128), _b128, 1); +#if __AVX512F__ + __m512 _b512 = (elempack == 16) ? _mm512_loadu_ps(bias + i * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b256), _b256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ - ptr += 8; - } + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size; j += 16) + { + __m512 _p512 = _mm512_loadu_ps(ptr); + _mm512_storeu_ps(ptr, _mm512_fmadd_ps(_p512, _s512, _b512)); + ptr += 16; } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#endif // __AVX512F__ + for (; j + 7 < size; j += 8) { - float* ptr = bottom_top_blob.channel(q); - __m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8); - - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_mul_ps(_p, _s); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } + __m256 _p256 = _mm256_loadu_ps(ptr); + _mm256_storeu_ps(ptr, _mm256_comp_fmadd_ps(_p256, _s256, _b256)); + ptr += 8; } - } - } - return 0; - } #endif // __AVX__ - - if (elempack == 4) - { - if (dims == 2) - { - if (bias_term) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) + for (; j + 3 < size; j += 4) { - float* ptr = bottom_top_blob.row(i); - __m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4); - __m128 _bias = _mm_loadu_ps((const float*)bias_data + i * 4); - - for (int j = 0; j < w; j++) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m128 _p128 = _mm_loadu_ps(ptr); + _mm_storeu_ps(ptr, _mm_comp_fmadd_ps(_p128, _s128, _b128)); + ptr += 4; + } +#endif // __SSE__ + for (; j < size; j++) + { + *ptr = *ptr * s + b; + ptr++; } } else { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size; j += 16) { - float* ptr = bottom_top_blob.row(i); - __m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4); - - for (int j = 0; j < w; j++) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_mul_ps(_p, _s); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m512 _p512 = _mm512_loadu_ps(ptr); + _mm512_storeu_ps(ptr, _mm512_mul_ps(_p512, _s512)); + ptr += 16; } - } - } - - if (dims == 3) - { - int size = w * h; - - if (bias_term) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#endif // __AVX512F__ + for (; j + 7 < size; j += 8) { - float* ptr = bottom_top_blob.channel(q); - __m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4); - __m128 _bias = _mm_loadu_ps((const float*)bias_data + q * 4); - - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m256 _p256 = _mm256_loadu_ps(ptr); + _mm256_storeu_ps(ptr, _mm256_mul_ps(_p256, _s256)); + ptr += 8; } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#endif // __AVX__ + for (; j + 3 < size; j += 4) { - float* ptr = bottom_top_blob.channel(q); - __m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4); - - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_mul_ps(_p, _s); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m128 _p128 = _mm_loadu_ps(ptr); + _mm_storeu_ps(ptr, _mm_mul_ps(_p128, _s128)); + ptr += 4; + } +#endif // __SSE__ + for (; j < size; j++) + { + *ptr = *ptr * s; + ptr++; } } } } -#endif // __SSE2__ - if (elempack == 1) + if (dims == 3 || dims == 4) { - if (dims == 2) - { - int size = w; - if (bias_term) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) - { - float* ptr = bottom_top_blob.row(i); + const int size = w * h * d * elempack; - float s = scale_blob[i]; - float bias = bias_data[i]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); - int j = 0; + float s = scale[q]; #if __SSE2__ + __m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale + q * 4) : _mm_set1_ps(s); #if __AVX__ - __m256 _s = _mm256_set1_ps(s); - __m256 _bias = _mm256_set1_ps(bias); - - for (; j + 7 < size; j += 8) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_comp_fmadd_ps(_p, _s, _bias); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } -#else - __m128 _s = _mm_set1_ps(s); - __m128 _bias = _mm_set1_ps(bias); - - for (; j + 3 < size; j += 4) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_comp_fmadd_ps(_p, _s, _bias); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale + q * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_s128), _s128, 1); +#if __AVX512F__ + __m512 _s512 = (elempack == 16) ? _mm512_loadu_ps(scale + q * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_s256), _s256, 1); +#endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ - for (; j < size; j++) - { - *ptr = *ptr * s + bias; - - ptr++; - } - } - } - else + if (bias_term) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; i++) - { - float* ptr = bottom_top_blob.row(i); - - float s = scale_blob[i]; - - int j = 0; + float b = bias[q]; #if __SSE2__ + __m128 _b128 = (elempack == 4) ? _mm_loadu_ps(bias + q * 4) : _mm_set1_ps(b); #if __AVX__ - __m256 _s = _mm256_set1_ps(s); - - for (; j + 7 < size; j += 8) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_mul_ps(_p, _s); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } -#else - __m128 _s = _mm_set1_ps(s); - - for (; j + 3 < size; j += 4) - { - __m128 _p = _mm_loadu_ps(ptr); - _p = _mm_mul_ps(_p, _s); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } + __m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(bias + q * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b128), _b128, 1); +#if __AVX512F__ + __m512 _b512 = (elempack == 16) ? _mm512_loadu_ps(bias + q * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b256), _b256, 1); +#endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ - for (; j < size; j++) - { - *ptr *= s; - - ptr++; - } - } - } - } - - if (dims == 3) - { - int size = w * h; - if (bias_term) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < channels; i++) - { - float* ptr = bottom_top_blob.channel(i); - - float s = scale_blob[i]; - - int j = 0; + int i = 0; #if __SSE2__ #if __AVX__ - __m256 _s256 = _mm256_set1_ps(s); - __m256 _bias256 = _mm256_set1_ps(bias_data[i]); - for (; j + 7 < size; j += 8) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_comp_fmadd_ps(_p, _s256, _bias256); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p512 = _mm512_loadu_ps(ptr); + _mm512_storeu_ps(ptr, _mm512_fmadd_ps(_p512, _s512, _b512)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p256 = _mm256_loadu_ps(ptr); + _mm256_storeu_ps(ptr, _mm256_comp_fmadd_ps(_p256, _s256, _b256)); + ptr += 8; + } #endif // __AVX__ - __m128 _s128 = _mm_set1_ps(s); - __m128 _bias128 = _mm_set1_ps(bias_data[i]); - for (; j < size; j += 4) - { - __m128 _p = _mm_load_ps(ptr); - _p = _mm_comp_fmadd_ps(_p, _s128, _bias128); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } -#endif // __SSE2__ - - for (; j < size; j++) - { - *ptr = *ptr * s + bias_data[i]; - ptr++; - } + for (; i + 3 < size; i += 4) + { + __m128 _p128 = _mm_loadu_ps(ptr); + _mm_storeu_ps(ptr, _mm_comp_fmadd_ps(_p128, _s128, _b128)); + ptr += 4; + } +#endif // __SSE__ + for (; i < size; i++) + { + *ptr = *ptr * s + b; + ptr++; } } else { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < channels; i++) - { - float* ptr = bottom_top_blob.channel(i); - - float s = scale_blob[i]; - - int j = 0; + int i = 0; #if __SSE2__ #if __AVX__ - __m256 _s256 = _mm256_set1_ps(s); - for (; j + 7 < size; j += 8) - { - __m256 _p = _mm256_loadu_ps(ptr); - _p = _mm256_mul_ps(_p, _s256); - _mm256_storeu_ps(ptr, _p); - - ptr += 8; - } +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p512 = _mm512_loadu_ps(ptr); + _mm512_storeu_ps(ptr, _mm512_mul_ps(_p512, _s512)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p256 = _mm256_loadu_ps(ptr); + _mm256_storeu_ps(ptr, _mm256_mul_ps(_p256, _s256)); + ptr += 8; + } #endif // __AVX__ - - __m128 _s128 = _mm_set1_ps(s); - for (; j < size; j += 4) - { - __m128 _p = _mm_load_ps(ptr); - _p = _mm_mul_ps(_p, _s128); - _mm_storeu_ps(ptr, _p); - - ptr += 4; - } -#endif // __SSE2__ - - for (; j < size; j++) - { - *ptr *= s; - ptr++; - } + for (; i + 3 < size; i += 4) + { + __m128 _p128 = _mm_loadu_ps(ptr); + _mm_storeu_ps(ptr, _mm_mul_ps(_p128, _s128)); + ptr += 4; + } +#endif // __SSE__ + for (; i < size; i++) + { + *ptr = *ptr * s; + ptr++; } } } diff --git a/tests/test_scale.cpp b/tests/test_scale.cpp index bd114dd24..e4045a0e9 100644 --- a/tests/test_scale.cpp +++ b/tests/test_scale.cpp @@ -68,6 +68,8 @@ static int test_scale_attention(const ncnn::Mat& a) static int test_scale_0() { return 0 + || test_scale(RandomMat(5, 3, 48), 0) + || test_scale(RandomMat(5, 3, 48), 1) || test_scale(RandomMat(5, 7, 24), 0) || test_scale(RandomMat(5, 7, 24), 1) || test_scale(RandomMat(7, 9, 12), 0) @@ -79,6 +81,8 @@ static int test_scale_0() static int test_scale_1() { return 0 + || test_scale(RandomMat(13, 48), 0) + || test_scale(RandomMat(13, 48), 1) || test_scale(RandomMat(15, 24), 0) || test_scale(RandomMat(15, 24), 1) || test_scale(RandomMat(17, 12), 0) @@ -101,6 +105,7 @@ static int test_scale_2() static int test_scale_3() { return 0 + || test_scale_attention(RandomMat(5, 6, 48)) || test_scale_attention(RandomMat(5, 7, 24)) || test_scale_attention(RandomMat(7, 9, 12)) || test_scale_attention(RandomMat(3, 5, 13)); @@ -109,6 +114,7 @@ static int test_scale_3() static int test_scale_4() { return 0 + || test_scale_attention(RandomMat(25, 48)) || test_scale_attention(RandomMat(15, 24)) || test_scale_attention(RandomMat(17, 12)) || test_scale_attention(RandomMat(19, 15));