Browse Source

Merge pull request #5157 from manaalmj/feature

Optimize gemv_n_sve kernel
tags/v0.3.30
Martin Kroeker GitHub 1 year ago
parent
commit
a3e7b16072
No known key found for this signature in database. GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 13 deletions
  1. kernel/arm64/KERNEL.ARMV8SVE (+1, -1)
  2. kernel/arm64/gemv_n_sve.c (+71, -12)

+ 1
- 1
kernel/arm64/KERNEL.ARMV8SVE View File

@@ -74,7 +74,7 @@ DSCALKERNEL = scal.S
CSCALKERNEL = zscal.S CSCALKERNEL = zscal.S
ZSCALKERNEL = zscal.S ZSCALKERNEL = zscal.S


SGEMVNKERNEL = gemv_n.S
SGEMVNKERNEL = gemv_n_sve.c
DGEMVNKERNEL = gemv_n.S DGEMVNKERNEL = gemv_n.S
CGEMVNKERNEL = zgemv_n.S CGEMVNKERNEL = zgemv_n.S
ZGEMVNKERNEL = zgemv_n.S ZGEMVNKERNEL = zgemv_n.S


+ 71
- 12
kernel/arm64/gemv_n_sve.c View File

@@ -1,5 +1,5 @@
/*************************************************************************** /***************************************************************************
Copyright (c) 2024, The OpenBLAS Project
Copyright (c) 2024-2025, The OpenBLAS Project
All rights reserved. All rights reserved.


Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
@@ -59,23 +59,82 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
a_ptr = a; a_ptr = a;


if (inc_y == 1) { if (inc_y == 1) {
BLASLONG width = n / 3;
uint64_t sve_size = SV_COUNT(); uint64_t sve_size = SV_COUNT();
for (j = 0; j < n; j++) {
SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
i = 0;
svbool_t pg = SV_WHILE(i, m);
while (svptest_any(SV_TRUE(), pg)) {
SV_TYPE a_vec = svld1(pg, a_ptr + i);
svbool_t pg_true = SV_TRUE();
svbool_t pg = SV_WHILE(0, m % sve_size);

FLOAT *a0_ptr = a + lda * width * 0;
FLOAT *a1_ptr = a + lda * width * 1;
FLOAT *a2_ptr = a + lda * width * 2;

for (j = 0; j < width; j++) {
for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
ix = j * inc_x;

SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);

SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);

SV_TYPE y_vec = svld1(pg_true, y + i);
y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0);
y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0);
y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0);

svst1(pg_true, y + i, y_vec);
}

if (i < m) {
SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);

SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
SV_TYPE a02_vec = svld1(pg, a2_ptr + i);

SV_TYPE y_vec = svld1(pg, y + i); SV_TYPE y_vec = svld1(pg, y + i);
y_vec = svmla_x(pg, y_vec, temp_vec, a_vec);
y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec);
y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec);
y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec);

ix += inc_x;

svst1(pg, y + i, y_vec); svst1(pg, y + i, y_vec);
i += sve_size;
pg = SV_WHILE(i, m);
} }

a0_ptr += lda;
a1_ptr += lda;
a2_ptr += lda;
}

a_ptr = a2_ptr;
for (j = width * 3; j < n; j++) {
ix = j * inc_x;
for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
SV_TYPE y_vec = svld1(pg_true, y + i);
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec);
svst1(pg_true, y + i, y_vec);
}

if (i < m) {
SV_TYPE y_vec = svld1(pg, y + i);
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
SV_TYPE a_vec = svld1(pg, a_ptr + i);
y_vec = svmla_m(pg, y_vec, a_vec, x_vec);
svst1(pg, y + i, y_vec);
}

a_ptr += lda; a_ptr += lda;
ix += inc_x; ix += inc_x;
} }
return(0);
return (0);
} }


for (j = 0; j < n; j++) { for (j = 0; j < n; j++) {
@@ -89,4 +148,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
ix += inc_x; ix += inc_x;
} }
return (0); return (0);
}
}

Loading…
Cancel
Save