Browse Source

fix scale avx512 (#4580)

tags/20230517
nihui GitHub 3 years ago
parent
commit
6987efd950
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 229 additions and 349 deletions
  1. +223
    -349
      src/layer/x86/scale_x86.cpp
  2. +6
    -0
      tests/test_scale.cpp

+ 223
- 349
src/layer/x86/scale_x86.cpp View File

@@ -37,6 +37,7 @@ int Scale_x86::forward_inplace(std::vector<Mat>& bottom_top_blobs, const Option&

const int w = bottom_top_blob.w;
const int h = bottom_top_blob.h;
const int d = bottom_top_blob.d;
const int channels = bottom_top_blob.c;
const int dims = bottom_top_blob.dims;

@@ -48,427 +49,300 @@ int Scale_x86::forward_inplace(std::vector<Mat>& bottom_top_blobs, const Option&
if (dims == 1)
{
float* ptr = (float*)bottom_top_blob;
int size = w * elempack;
const int size = w * elempack;

int remain = size;
if (bias_term)
{
int nn_size = 0;
int remain_size_start = 0;
#if __SSE2__
#if __AVX__
int nn = size >> 3;
remain = size & 7;
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < nn; i++)
{
__m256 _p = _mm256_loadu_ps(ptr + i * 8);
__m256 _s = _mm256_loadu_ps(scale + i * 8);
if (bias_term)
#if __AVX512F__
nn_size = (size - remain_size_start) / 16;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
__m256 _bias = _mm256_loadu_ps(bias + i * 8);
_p = _mm256_comp_fmadd_ps(_p, _s, _bias);
int i = remain_size_start + ii * 16;
__m512 _p512 = _mm512_loadu_ps(ptr + i);
__m512 _s512 = _mm512_loadu_ps(scale + i);
__m512 _bias512 = _mm512_loadu_ps(bias + i);
_mm512_storeu_ps(ptr + i, _mm512_fmadd_ps(_p512, _s512, _bias512));
}
else
remain_size_start += nn_size * 16;
#endif // __AVX512F__
nn_size = (size - remain_size_start) / 8;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
int i = remain_size_start + ii * 8;
__m256 _p256 = _mm256_loadu_ps(ptr + i);
__m256 _s256 = _mm256_loadu_ps(scale + i);
__m256 _bias256 = _mm256_loadu_ps(bias + i);
_mm256_storeu_ps(ptr + i, _mm256_comp_fmadd_ps(_p256, _s256, _bias256));
}
remain_size_start += nn_size * 8;
#endif // __AVX__
nn_size = (size - remain_size_start) / 4;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
int i = remain_size_start + ii * 4;
__m128 _p128 = _mm_load_ps(ptr + i);
__m128 _s128 = _mm_load_ps(scale + i);
__m128 _bias128 = _mm_loadu_ps(bias + i);
_mm_store_ps(ptr + i, _mm_comp_fmadd_ps(_p128, _s128, _bias128));
}
remain_size_start += nn_size * 4;
#endif // __SSE2__
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = remain_size_start; i < size; i++)
{
_p = _mm256_mul_ps(_p, _s);
ptr[i] = ptr[i] * scale[i] + bias[i];
}
_mm256_storeu_ps(ptr + i * 8, _p);
}
#else
int nn = size >> 2;
remain = size & 3;
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < nn; i++)
else
{
__m128 _p = _mm_loadu_ps(ptr + i * 4);
__m128 _s = _mm_loadu_ps(scale + i * 4);
if (bias_term)
int nn_size = 0;
int remain_size_start = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
nn_size = (size - remain_size_start) / 16;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
__m128 _bias = _mm_loadu_ps(bias + i * 4);
_p = _mm_comp_fmadd_ps(_p, _s, _bias);
int i = remain_size_start + ii * 16;
__m512 _p512 = _mm512_loadu_ps(ptr + i);
__m512 _s512 = _mm512_loadu_ps(scale + i);
_mm512_storeu_ps(ptr + i, _mm512_mul_ps(_p512, _s512));
}
else
remain_size_start += nn_size * 16;
#endif // __AVX512F__
nn_size = (size - remain_size_start) / 8;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
_p = _mm_mul_ps(_p, _s);
int i = remain_size_start + ii * 8;
__m256 _p256 = _mm256_loadu_ps(ptr + i);
__m256 _s256 = _mm256_loadu_ps(scale + i);
_mm256_storeu_ps(ptr + i, _mm256_mul_ps(_p256, _s256));
}
_mm_storeu_ps(ptr + i * 4, _p);
}
remain_size_start += nn_size * 8;
#endif // __AVX__
#endif // __SSE2__
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = size - remain; i < size; i++)
{
if (bias_term)
nn_size = (size - remain_size_start) / 4;
#pragma omp parallel for num_threads(opt.num_threads)
for (int ii = 0; ii < nn_size; ii++)
{
ptr[i] = ptr[i] * scale[i] + bias[i];
int i = remain_size_start + ii * 4;
__m128 _p128 = _mm_load_ps(ptr + i);
__m128 _s128 = _mm_load_ps(scale + i);
_mm_store_ps(ptr + i, _mm_mul_ps(_p128, _s128));
}
else
remain_size_start += nn_size * 4;
#endif // __SSE2__
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = remain_size_start; i < size; i++)
{
ptr[i] = ptr[i] * scale[i];
}
}

return 0;
}

#if __SSE2__
#if __AVX__
if (elempack == 8)
if (dims == 2)
{
if (dims == 2)
{
if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.row(i);
__m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8);
__m256 _bias = _mm256_loadu_ps((const float*)bias_data + i * 8);

for (int j = 0; j < w; j++)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_comp_fmadd_ps(_p, _s, _bias);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.row(i);
__m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8);

for (int j = 0; j < w; j++)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_mul_ps(_p, _s);
_mm256_storeu_ps(ptr, _p);
const int size = w * elempack;

ptr += 8;
}
}
}
}

if (dims == 3)
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
int size = w * h;
float* ptr = bottom_top_blob.row(i);

float s = scale[i];
#if __SSE2__
__m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale + i * 4) : _mm_set1_ps(s);
#if __AVX__
__m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale + i * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_s128), _s128, 1);
#if __AVX512F__
__m512 _s512 = (elempack == 16) ? _mm512_loadu_ps(scale + i * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_s256), _s256, 1);
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__

if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
float* ptr = bottom_top_blob.channel(q);
__m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8);
__m256 _bias = _mm256_loadu_ps((const float*)bias_data + q * 8);

for (int i = 0; i < size; i++)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_comp_fmadd_ps(_p, _s, _bias);
_mm256_storeu_ps(ptr, _p);
float b = bias[i];
#if __SSE2__
__m128 _b128 = (elempack == 4) ? _mm_loadu_ps(bias + i * 4) : _mm_set1_ps(b);
#if __AVX__
__m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(bias + i * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b128), _b128, 1);
#if __AVX512F__
__m512 _b512 = (elempack == 16) ? _mm512_loadu_ps(bias + i * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b256), _b256, 1);
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__

ptr += 8;
}
int j = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; j + 15 < size; j += 16)
{
__m512 _p512 = _mm512_loadu_ps(ptr);
_mm512_storeu_ps(ptr, _mm512_fmadd_ps(_p512, _s512, _b512));
ptr += 16;
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
#endif // __AVX512F__
for (; j + 7 < size; j += 8)
{
float* ptr = bottom_top_blob.channel(q);
__m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8);

for (int i = 0; i < size; i++)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_mul_ps(_p, _s);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
__m256 _p256 = _mm256_loadu_ps(ptr);
_mm256_storeu_ps(ptr, _mm256_comp_fmadd_ps(_p256, _s256, _b256));
ptr += 8;
}
}
}
return 0;
}
#endif // __AVX__

if (elempack == 4)
{
if (dims == 2)
{
if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
for (; j + 3 < size; j += 4)
{
float* ptr = bottom_top_blob.row(i);
__m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4);
__m128 _bias = _mm_loadu_ps((const float*)bias_data + i * 4);

for (int j = 0; j < w; j++)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m128 _p128 = _mm_loadu_ps(ptr);
_mm_storeu_ps(ptr, _mm_comp_fmadd_ps(_p128, _s128, _b128));
ptr += 4;
}
#endif // __SSE2__
for (; j < size; j++)
{
*ptr = *ptr * s + b;
ptr++;
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
int j = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; j + 15 < size; j += 16)
{
float* ptr = bottom_top_blob.row(i);
__m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4);

for (int j = 0; j < w; j++)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_mul_ps(_p, _s);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m512 _p512 = _mm512_loadu_ps(ptr);
_mm512_storeu_ps(ptr, _mm512_mul_ps(_p512, _s512));
ptr += 16;
}
}
}

if (dims == 3)
{
int size = w * h;

if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
#endif // __AVX512F__
for (; j + 7 < size; j += 8)
{
float* ptr = bottom_top_blob.channel(q);
__m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4);
__m128 _bias = _mm_loadu_ps((const float*)bias_data + q * 4);

for (int i = 0; i < size; i++)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m256 _p256 = _mm256_loadu_ps(ptr);
_mm256_storeu_ps(ptr, _mm256_mul_ps(_p256, _s256));
ptr += 8;
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
#endif // __AVX__
for (; j + 3 < size; j += 4)
{
float* ptr = bottom_top_blob.channel(q);
__m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4);

for (int i = 0; i < size; i++)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_mul_ps(_p, _s);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m128 _p128 = _mm_loadu_ps(ptr);
_mm_storeu_ps(ptr, _mm_mul_ps(_p128, _s128));
ptr += 4;
}
#endif // __SSE2__
for (; j < size; j++)
{
*ptr = *ptr * s;
ptr++;
}
}
}
}
#endif // __SSE2__

if (elempack == 1)
if (dims == 3 || dims == 4)
{
if (dims == 2)
{
int size = w;
if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.row(i);
const int size = w * h * d * elempack;

float s = scale_blob[i];
float bias = bias_data[i];
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
float* ptr = bottom_top_blob.channel(q);

int j = 0;
float s = scale[q];
#if __SSE2__
__m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale + q * 4) : _mm_set1_ps(s);
#if __AVX__
__m256 _s = _mm256_set1_ps(s);
__m256 _bias = _mm256_set1_ps(bias);

for (; j + 7 < size; j += 8)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_comp_fmadd_ps(_p, _s, _bias);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
#else
__m128 _s = _mm_set1_ps(s);
__m128 _bias = _mm_set1_ps(bias);

for (; j + 3 < size; j += 4)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_comp_fmadd_ps(_p, _s, _bias);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale + q * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_s128), _s128, 1);
#if __AVX512F__
__m512 _s512 = (elempack == 16) ? _mm512_loadu_ps(scale + q * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_s256), _s256, 1);
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__

for (; j < size; j++)
{
*ptr = *ptr * s + bias;

ptr++;
}
}
}
else
if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.row(i);

float s = scale_blob[i];

int j = 0;
float b = bias[q];
#if __SSE2__
__m128 _b128 = (elempack == 4) ? _mm_loadu_ps(bias + q * 4) : _mm_set1_ps(b);
#if __AVX__
__m256 _s = _mm256_set1_ps(s);

for (; j + 7 < size; j += 8)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_mul_ps(_p, _s);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
#else
__m128 _s = _mm_set1_ps(s);

for (; j + 3 < size; j += 4)
{
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_mul_ps(_p, _s);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
__m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(bias + q * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b128), _b128, 1);
#if __AVX512F__
__m512 _b512 = (elempack == 16) ? _mm512_loadu_ps(bias + q * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b256), _b256, 1);
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__

for (; j < size; j++)
{
*ptr *= s;

ptr++;
}
}
}
}

if (dims == 3)
{
int size = w * h;
if (bias_term)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < channels; i++)
{
float* ptr = bottom_top_blob.channel(i);

float s = scale_blob[i];

int j = 0;
int i = 0;
#if __SSE2__
#if __AVX__
__m256 _s256 = _mm256_set1_ps(s);
__m256 _bias256 = _mm256_set1_ps(bias_data[i]);
for (; j + 7 < size; j += 8)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_comp_fmadd_ps(_p, _s256, _bias256);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
#if __AVX512F__
for (; i + 15 < size; i += 16)
{
__m512 _p512 = _mm512_loadu_ps(ptr);
_mm512_storeu_ps(ptr, _mm512_fmadd_ps(_p512, _s512, _b512));
ptr += 16;
}
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m256 _p256 = _mm256_loadu_ps(ptr);
_mm256_storeu_ps(ptr, _mm256_comp_fmadd_ps(_p256, _s256, _b256));
ptr += 8;
}
#endif // __AVX__
__m128 _s128 = _mm_set1_ps(s);
__m128 _bias128 = _mm_set1_ps(bias_data[i]);
for (; j < size; j += 4)
{
__m128 _p = _mm_load_ps(ptr);
_p = _mm_comp_fmadd_ps(_p, _s128, _bias128);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
#endif // __SSE2__

for (; j < size; j++)
{
*ptr = *ptr * s + bias_data[i];
ptr++;
}
for (; i + 3 < size; i += 4)
{
__m128 _p128 = _mm_loadu_ps(ptr);
_mm_storeu_ps(ptr, _mm_comp_fmadd_ps(_p128, _s128, _b128));
ptr += 4;
}
#endif // __SSE2__
for (; i < size; i++)
{
*ptr = *ptr * s + b;
ptr++;
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < channels; i++)
{
float* ptr = bottom_top_blob.channel(i);

float s = scale_blob[i];

int j = 0;
int i = 0;
#if __SSE2__
#if __AVX__
__m256 _s256 = _mm256_set1_ps(s);
for (; j + 7 < size; j += 8)
{
__m256 _p = _mm256_loadu_ps(ptr);
_p = _mm256_mul_ps(_p, _s256);
_mm256_storeu_ps(ptr, _p);

ptr += 8;
}
#if __AVX512F__
for (; i + 15 < size; i += 16)
{
__m512 _p512 = _mm512_loadu_ps(ptr);
_mm512_storeu_ps(ptr, _mm512_mul_ps(_p512, _s512));
ptr += 16;
}
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m256 _p256 = _mm256_loadu_ps(ptr);
_mm256_storeu_ps(ptr, _mm256_mul_ps(_p256, _s256));
ptr += 8;
}
#endif // __AVX__

__m128 _s128 = _mm_set1_ps(s);
for (; j < size; j += 4)
{
__m128 _p = _mm_load_ps(ptr);
_p = _mm_mul_ps(_p, _s128);
_mm_storeu_ps(ptr, _p);

ptr += 4;
}
#endif // __SSE2__

for (; j < size; j++)
{
*ptr *= s;
ptr++;
}
for (; i + 3 < size; i += 4)
{
__m128 _p128 = _mm_loadu_ps(ptr);
_mm_storeu_ps(ptr, _mm_mul_ps(_p128, _s128));
ptr += 4;
}
#endif // __SSE2__
for (; i < size; i++)
{
*ptr = *ptr * s;
ptr++;
}
}
}


+ 6
- 0
tests/test_scale.cpp View File

@@ -68,6 +68,8 @@ static int test_scale_attention(const ncnn::Mat& a)
static int test_scale_0()
{
return 0
|| test_scale(RandomMat(5, 3, 48), 0)
|| test_scale(RandomMat(5, 3, 48), 1)
|| test_scale(RandomMat(5, 7, 24), 0)
|| test_scale(RandomMat(5, 7, 24), 1)
|| test_scale(RandomMat(7, 9, 12), 0)
@@ -79,6 +81,8 @@ static int test_scale_0()
static int test_scale_1()
{
return 0
|| test_scale(RandomMat(13, 48), 0)
|| test_scale(RandomMat(13, 48), 1)
|| test_scale(RandomMat(15, 24), 0)
|| test_scale(RandomMat(15, 24), 1)
|| test_scale(RandomMat(17, 12), 0)
@@ -101,6 +105,7 @@ static int test_scale_2()
static int test_scale_3()
{
return 0
|| test_scale_attention(RandomMat(5, 6, 48))
|| test_scale_attention(RandomMat(5, 7, 24))
|| test_scale_attention(RandomMat(7, 9, 12))
|| test_scale_attention(RandomMat(3, 5, 13));
@@ -109,6 +114,7 @@ static int test_scale_3()
static int test_scale_4()
{
return 0
|| test_scale_attention(RandomMat(25, 48))
|| test_scale_attention(RandomMat(15, 24))
|| test_scale_attention(RandomMat(17, 12))
|| test_scale_attention(RandomMat(19, 15));


Loading…
Cancel
Save