
Optimize gemv_n_sve kernel

tags/v0.3.30
manjam01, 11 months ago
commit 5c4e38ab17
2 changed files with 72 additions and 13 deletions
  1. kernel/arm64/KERNEL.ARMV8SVE (+1, -1)
  2. kernel/arm64/gemv_n_sve.c (+71, -12)
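
For the inc_y == 1 path, the rewrite walks three columns of A per pass instead of one: full vectors of y are updated under an all-true predicate with svmla_lane, the sub-vector remainder of m uses a tail predicate computed once up front, and the n % 3 leftover columns keep the original one-column pattern. A plain-C sketch of the same blocking (illustrative names, inc_x taken as 1 for brevity; not part of the patch):

/* Scalar reference for the blocking the SVE kernel implements below.
 * Computes y += alpha * A * x for a column-major m x n matrix A. */
static void gemv_n_blocked_ref(long m, long n, float alpha,
                               const float *a, long lda,
                               const float *x, float *y)
{
    long width = n / 3;
    const float *a0 = a + lda * width * 0;   /* column block 0 */
    const float *a1 = a + lda * width * 1;   /* column block 1 */
    const float *a2 = a + lda * width * 2;   /* column block 2 */

    for (long j = 0; j < width; j++) {
        float x0 = alpha * x[j + width * 0];
        float x1 = alpha * x[j + width * 1];
        float x2 = alpha * x[j + width * 2];
        for (long i = 0; i < m; i++)         /* one pass over y folds in three columns */
            y[i] += x0 * a0[i] + x1 * a1[i] + x2 * a2[i];
        a0 += lda;
        a1 += lda;
        a2 += lda;
    }
    for (long j = width * 3; j < n; j++) {   /* n % 3 leftover columns */
        float xj = alpha * x[j];
        const float *aj = a + lda * j;
        for (long i = 0; i < m; i++)
            y[i] += xj * aj[i];
    }
}

Reading three columns of A per sweep of y roughly triples the arithmetic done per load/store of y compared with the old column-at-a-time loop.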

kernel/arm64/KERNEL.ARMV8SVE (+1, -1)

@@ -74,7 +74,7 @@ DSCALKERNEL = scal.S
 CSCALKERNEL = zscal.S
 ZSCALKERNEL = zscal.S
 
-SGEMVNKERNEL = gemv_n.S
+SGEMVNKERNEL = gemv_n_sve.c
 DGEMVNKERNEL = gemv_n.S
 CGEMVNKERNEL = zgemv_n.S
 ZGEMVNKERNEL = zgemv_n.S
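
Only the single-precision forward kernel switches to the SVE C source (DGEMVNKERNEL and the complex variants are untouched), so the SV_* aliases used in the diff below resolve to their 32-bit forms. Those macros are defined at the top of gemv_n_sve.c, outside the hunks shown here; a sketch of what they presumably map to in the float build (assumed, not quoted from the file):

#include <arm_sve.h>

/* Assumed single-precision expansions of the kernel's helper macros. */
#define SV_COUNT()      svcntw()                 /* 32-bit lanes per SVE vector   */
#define SV_TYPE         svfloat32_t              /* vector of float               */
#define SV_TRUE()       svptrue_b32()            /* all-lanes-active predicate    */
#define SV_WHILE(a, b)  svwhilelt_b32((uint64_t)(a), (uint64_t)(b)) /* tail predicate */
#define SV_DUP(x)       svdup_f32(x)             /* broadcast scalar to all lanes */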


kernel/arm64/gemv_n_sve.c (+71, -12)

@@ -1,5 +1,5 @@
 /***************************************************************************
-Copyright (c) 2024, The OpenBLAS Project
+Copyright (c) 2024-2025, The OpenBLAS Project
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without
@@ -59,23 +59,82 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
     a_ptr = a;
 
     if (inc_y == 1) {
+        BLASLONG width = n / 3;
         uint64_t sve_size = SV_COUNT();
-        for (j = 0; j < n; j++) {
-            SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
-            i = 0;
-            svbool_t pg = SV_WHILE(i, m);
-            while (svptest_any(SV_TRUE(), pg)) {
-                SV_TYPE a_vec = svld1(pg, a_ptr + i);
+        svbool_t pg_true = SV_TRUE();
+        svbool_t pg = SV_WHILE(0, m % sve_size);
+
+        FLOAT *a0_ptr = a + lda * width * 0;
+        FLOAT *a1_ptr = a + lda * width * 1;
+        FLOAT *a2_ptr = a + lda * width * 2;
+
+        for (j = 0; j < width; j++) {
+            for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
+                ix = j * inc_x;
+
+                SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
+                SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
+                SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
+
+                SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);
+
+                SV_TYPE y_vec = svld1(pg_true, y + i);
+                y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0);
+                y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0);
+                y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0);
+
+                svst1(pg_true, y + i, y_vec);
+            }
+
+            if (i < m) {
+                SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
+                SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
+                SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
+
+                SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg, a2_ptr + i);
+
                 SV_TYPE y_vec = svld1(pg, y + i);
-                y_vec = svmla_x(pg, y_vec, temp_vec, a_vec);
+                y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec);
+                y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec);
+                y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec);
+
+                ix += inc_x;
+
                 svst1(pg, y + i, y_vec);
-                i += sve_size;
-                pg = SV_WHILE(i, m);
             }
+
+            a0_ptr += lda;
+            a1_ptr += lda;
+            a2_ptr += lda;
+        }
+
+        a_ptr = a2_ptr;
+        for (j = width * 3; j < n; j++) {
+            ix = j * inc_x;
+            for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
+                SV_TYPE y_vec = svld1(pg_true, y + i);
+                SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
+                SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
+                y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec);
+                svst1(pg_true, y + i, y_vec);
+            }
+
+            if (i < m) {
+                SV_TYPE y_vec = svld1(pg, y + i);
+                SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
+                SV_TYPE a_vec = svld1(pg, a_ptr + i);
+                y_vec = svmla_m(pg, y_vec, a_vec, x_vec);
+                svst1(pg, y + i, y_vec);
+            }
+
             a_ptr += lda;
             ix += inc_x;
         }
-        return(0);
+        return (0);
     }
 
     for (j = 0; j < n; j++) {
@@ -89,4 +148,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
         ix += inc_x;
     }
     return (0);
-}
+}
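
Two details of the new tail handling are easy to miss in the diff: the remainder predicate is computed once as pg = SV_WHILE(0, m % sve_size) and reused for every column (the main i-loop only runs while a whole vector still fits), and the full-vector body uses svmla_lane, which has no predicated form, so the remainder falls back to svmla_m under pg. A minimal standalone sketch of that main-loop/tail split on a single column, with a hypothetical function name and single precision assumed:

#include <arm_sve.h>
#include <stdint.h>

/* y += scale * a_col, split into whole-vector iterations plus one
 * predicated remainder, mirroring the pattern used by the new kernel. */
static void axpy_col_f32(int64_t m, float scale, const float *a_col, float *y)
{
    uint64_t vl = svcntw();                                   /* lanes per vector */
    svbool_t all = svptrue_b32();
    svbool_t tail = svwhilelt_b32((uint64_t)0, (uint64_t)(m % (int64_t)vl));
    svfloat32_t s = svdup_f32(scale);

    int64_t i = 0;
    for (; (i + (int64_t)vl - 1) < m; i += (int64_t)vl) {     /* full vectors     */
        svfloat32_t yv = svld1(all, y + i);
        svfloat32_t av = svld1(all, a_col + i);
        yv = svmla_lane(yv, av, s, 0);                        /* y += a * s[0]    */
        svst1(all, y + i, yv);
    }
    if (i < m) {                                              /* predicated tail  */
        svfloat32_t yv = svld1(tail, y + i);
        svfloat32_t av = svld1(tail, a_col + i);
        yv = svmla_m(tail, yv, av, s);
        svst1(tail, y + i, yv);
    }
}

Built with an SVE-enabled toolchain (e.g. -march=armv8-a+sve), this should compile to the same whole-vector FMLA loop plus predicated remainder shape as the new kernel, just without the three-column blocking.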
