Browse Source

Optimize gemv_n_sve kernel

tags/v0.3.30
manjam01 11 months ago
parent
commit
5c4e38ab17
2 changed files with 72 additions and 13 deletions
  1. kernel/arm64/KERNEL.ARMV8SVE (+1, -1)
  2. kernel/arm64/gemv_n_sve.c (+71, -12)

+ 1
- 1
kernel/arm64/KERNEL.ARMV8SVE View File

@@ -74,7 +74,7 @@ DSCALKERNEL = scal.S
 CSCALKERNEL = zscal.S
 ZSCALKERNEL = zscal.S

-SGEMVNKERNEL = gemv_n.S
+SGEMVNKERNEL = gemv_n_sve.c
 DGEMVNKERNEL = gemv_n.S
 CGEMVNKERNEL = zgemv_n.S
 ZGEMVNKERNEL = zgemv_n.S


+ 71
- 12
kernel/arm64/gemv_n_sve.c View File

@@ -1,5 +1,5 @@
 /***************************************************************************
-Copyright (c) 2024, The OpenBLAS Project
+Copyright (c) 2024-2025, The OpenBLAS Project
 All rights reserved.

 Redistribution and use in source and binary forms, with or without
@@ -59,23 +59,82 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
     a_ptr = a;

     if (inc_y == 1) {
+        BLASLONG width = n / 3;
         uint64_t sve_size = SV_COUNT();
-        for (j = 0; j < n; j++) {
-            SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
-            i = 0;
-            svbool_t pg = SV_WHILE(i, m);
-            while (svptest_any(SV_TRUE(), pg)) {
-                SV_TYPE a_vec = svld1(pg, a_ptr + i);
+        svbool_t pg_true = SV_TRUE();
+        svbool_t pg = SV_WHILE(0, m % sve_size);
+
+        FLOAT *a0_ptr = a + lda * width * 0;
+        FLOAT *a1_ptr = a + lda * width * 1;
+        FLOAT *a2_ptr = a + lda * width * 2;
+
+        for (j = 0; j < width; j++) {
+            for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
+                ix = j * inc_x;
+
+                SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
+                SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
+                SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
+
+                SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);
+
+                SV_TYPE y_vec = svld1(pg_true, y + i);
+                y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0);
+                y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0);
+                y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0);
+
+                svst1(pg_true, y + i, y_vec);
+            }
+
+            if (i < m) {
+                SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
+                SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
+                SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
+
+                SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg, a2_ptr + i);
+
                 SV_TYPE y_vec = svld1(pg, y + i);
-                y_vec = svmla_x(pg, y_vec, temp_vec, a_vec);
+                y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec);
+                y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec);
+                y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec);
+
+                ix += inc_x;
+
                 svst1(pg, y + i, y_vec);
-                i += sve_size;
-                pg = SV_WHILE(i, m);
             }
+
+            a0_ptr += lda;
+            a1_ptr += lda;
+            a2_ptr += lda;
+        }
+
+        a_ptr = a2_ptr;
+        for (j = width * 3; j < n; j++) {
+            ix = j * inc_x;
+            for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
+                SV_TYPE y_vec = svld1(pg_true, y + i);
+                SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
+                SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
+                y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec);
+                svst1(pg_true, y + i, y_vec);
+            }
+
+            if (i < m) {
+                SV_TYPE y_vec = svld1(pg, y + i);
+                SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
+                SV_TYPE a_vec = svld1(pg, a_ptr + i);
+                y_vec = svmla_m(pg, y_vec, a_vec, x_vec);
+                svst1(pg, y + i, y_vec);
+            }
+
             a_ptr += lda;
             ix += inc_x;
         }
-        return(0);
+        return (0);
     }

     for (j = 0; j < n; j++) {
@@ -89,4 +148,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
         ix += inc_x;
     }
     return (0);
-}
+}

Loading…
Cancel
Save