Browse Source

Add AVX512 support to DDOT

now that it's written in C + intrinsics it's easy to add AVX512 support
for DDOT
pull/1712/head
Arjan van de Ven 7 years ago
parent
commit
34d63df4b3
1 changed files with 32 additions and 5 deletions
  1. +32
    -5
      kernel/x86_64/ddot_microk_haswell-2.c

+ 32
- 5
kernel/x86_64/ddot_microk_haswell-2.c View File

@@ -26,7 +26,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD_

#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

@@ -35,7 +36,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define HAVE_KERNEL_8 1

#include <immintrin.h>
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));

static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{
@@ -47,10 +47,37 @@ static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
accum_2 = _mm256_setzero_pd();
accum_3 = _mm256_setzero_pd();

#ifdef __AVX512CD__
__m512d accum_05, accum_15, accum_25, accum_35;
int n32;
n32 = n & (~31);

accum_05 = _mm512_setzero_pd();
accum_15 = _mm512_setzero_pd();
accum_25 = _mm512_setzero_pd();
accum_35 = _mm512_setzero_pd();

for (; i < n32; i += 32) {
accum_05 += _mm512_loadu_pd(&x[i+ 0]) * _mm512_loadu_pd(&y[i+ 0]);
accum_15 += _mm512_loadu_pd(&x[i+ 8]) * _mm512_loadu_pd(&y[i+ 8]);
accum_25 += _mm512_loadu_pd(&x[i+16]) * _mm512_loadu_pd(&y[i+16]);
accum_35 += _mm512_loadu_pd(&x[i+24]) * _mm512_loadu_pd(&y[i+24]);
}

/*
* we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
* below can continue using the intermediate results in its loop
*/
accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));

#endif
for (; i < n; i += 16) {
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+0]);
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+4]);
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+8]);
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+ 0]);
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+ 4]);
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+ 8]);
accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
}



Loading…
Cancel
Save