| @@ -182,8 +182,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| i = m % sve_size; | i = m % sve_size; | ||||
| if (i) { | if (i) { | ||||
| aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE; | |||||
| cc = c + ((m & ~(i - 1)) - i) * COMPSIZE; | |||||
| aa = a + (m - i) * k * COMPSIZE; | |||||
| cc = c + (m - i) * COMPSIZE; | |||||
| if (k - kk > 0) { | if (k - kk > 0) { | ||||
| GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1, | GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1, | ||||
| @@ -205,10 +205,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| } | } | ||||
| int mod = i; | |||||
| i = sve_size; | i = sve_size; | ||||
| if (i <= m) { | if (i <= m) { | ||||
| aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE; | |||||
| cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE; | |||||
| aa = a + (m - mod - sve_size) * k * COMPSIZE; | |||||
| cc = c + (m - mod - sve_size) * COMPSIZE; | |||||
| do { | do { | ||||
| if (k - kk > 0) { | if (k - kk > 0) { | ||||
| @@ -217,7 +218,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| ZERO, | ZERO, | ||||
| #endif | #endif | ||||
| aa + sve_size * kk * COMPSIZE, | aa + sve_size * kk * COMPSIZE, | ||||
| b + sve_size * kk * COMPSIZE, | |||||
| b + GEMM_UNROLL_N * kk * COMPSIZE, | |||||
| cc, | cc, | ||||
| ldc); | ldc); | ||||
| } | } | ||||
| @@ -251,8 +252,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| i = m % sve_size; | i = m % sve_size; | ||||
| if (i) { | if (i) { | ||||
| aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE; | |||||
| cc = c + ((m & ~(i - 1)) - i) * COMPSIZE; | |||||
| aa = a + (m - i) * k * COMPSIZE; | |||||
| cc = c + (m - i) * COMPSIZE; | |||||
| if (k - kk > 0) { | if (k - kk > 0) { | ||||
| GEMM_KERNEL(i, j, k - kk, dm1, | GEMM_KERNEL(i, j, k - kk, dm1, | ||||
| @@ -273,10 +274,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| } | } | ||||
| int mod = i; | |||||
| i = sve_size; | i = sve_size; | ||||
| if (i <= m) { | if (i <= m) { | ||||
| aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE; | |||||
| cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE; | |||||
| aa = a + (m - mod - sve_size) * k * COMPSIZE; | |||||
| cc = c + (m - mod - sve_size) * COMPSIZE; | |||||
| do { | do { | ||||
| if (k - kk > 0) { | if (k - kk > 0) { | ||||
| @@ -257,7 +257,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| i += sve_size; | i += sve_size; | ||||
| } | } | ||||
| i = sve_size % m; | |||||
| i = m % sve_size; | |||||
| if (i) { | if (i) { | ||||
| if (kk > 0) { | if (kk > 0) { | ||||
| GEMM_KERNEL(i, j, kk, dm1, | GEMM_KERNEL(i, j, kk, dm1, | ||||
| @@ -258,23 +258,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, | |||||
| if (i <= m) { | if (i <= m) { | ||||
| do { | do { | ||||
| if (k - kk > 0) { | if (k - kk > 0) { | ||||
| GEMM_KERNEL(GEMM_UNROLL_M, GEMM_UNROLL_N, k - kk, dm1, | |||||
| GEMM_KERNEL(sve_size, GEMM_UNROLL_N, k - kk, dm1, | |||||
| #ifdef COMPLEX | #ifdef COMPLEX | ||||
| ZERO, | ZERO, | ||||
| #endif | #endif | ||||
| aa + GEMM_UNROLL_M * kk * COMPSIZE, | |||||
| aa + sve_size * kk * COMPSIZE, | |||||
| b + GEMM_UNROLL_N * kk * COMPSIZE, | b + GEMM_UNROLL_N * kk * COMPSIZE, | ||||
| cc, | cc, | ||||
| ldc); | ldc); | ||||
| } | } | ||||
| solve(GEMM_UNROLL_M, GEMM_UNROLL_N, | |||||
| aa + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_M * COMPSIZE, | |||||
| solve(sve_size, GEMM_UNROLL_N, | |||||
| aa + (kk - GEMM_UNROLL_N) * sve_size * COMPSIZE, | |||||
| b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE, | b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE, | ||||
| cc, ldc); | cc, ldc); | ||||
| aa += GEMM_UNROLL_M * k * COMPSIZE; | |||||
| cc += GEMM_UNROLL_M * COMPSIZE; | |||||
| aa += sve_size * k * COMPSIZE; | |||||
| cc += sve_size * COMPSIZE; | |||||
| i += sve_size; | i += sve_size; | ||||
| } while (i <= m); | } while (i <= m); | ||||
| } | } | ||||
| @@ -48,17 +48,18 @@ | |||||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | ||||
| BLASLONG i, ii, j, jj; | |||||
| BLASLONG i, ii, jj; | |||||
| FLOAT *ao; | FLOAT *ao; | ||||
| jj = offset; | jj = offset; | ||||
| int js = 0; | |||||
| #ifdef DOUBLE | #ifdef DOUBLE | ||||
| int64_t js = 0; | |||||
| svint64_t index = svindex_s64(0LL, lda); | svint64_t index = svindex_s64(0LL, lda); | ||||
| svbool_t pn = svwhilelt_b64(js, n); | svbool_t pn = svwhilelt_b64(js, n); | ||||
| int n_active = svcntp_b64(svptrue_b64(), pn); | int n_active = svcntp_b64(svptrue_b64(), pn); | ||||
| #else | #else | ||||
| int32_t js = 0; | |||||
| svint32_t index = svindex_s32(0, lda); | svint32_t index = svindex_s32(0, lda); | ||||
| svbool_t pn = svwhilelt_b32(js, n); | svbool_t pn = svwhilelt_b32(js, n); | ||||
| int n_active = svcntp_b32(svptrue_b32(), pn); | int n_active = svcntp_b32(svptrue_b32(), pn); | ||||
| @@ -74,25 +75,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT | |||||
| if (ii == jj) { | if (ii == jj) { | ||||
| for (int j = 0; j < n_active; j++) { | for (int j = 0; j < n_active; j++) { | ||||
| for (int k = 0; k < j; k++) { | for (int k = 0; k < j; k++) { | ||||
| *(b + j * n_active + k) = *(a + k * lda + j); | |||||
| *(b + j * n_active + k) = *(ao + k * lda + j); | |||||
| } | } | ||||
| *(b + j * n_active + j) = INV(*(a + j * lda + j)); | |||||
| *(b + j * n_active + j) = INV(*(ao + j * lda + j)); | |||||
| } | } | ||||
| } | |||||
| if (ii > jj) { | |||||
| for (int j = 0; j < n_active; j++) { | |||||
| ao += n_active; | |||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } else { | |||||
| if (ii > jj) { | |||||
| svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); | svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); | ||||
| svst1(pn, b, aj_vec); | svst1(pn, b, aj_vec); | ||||
| ao++; | |||||
| } | } | ||||
| ao++; | |||||
| b += n_active; | |||||
| i++; | |||||
| ii++; | |||||
| } | } | ||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } while (i < m); | } while (i < m); | ||||
| @@ -48,18 +48,17 @@ | |||||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | ||||
| BLASLONG i, ii, j, jj; | |||||
| BLASLONG i, ii, jj; | |||||
| FLOAT *ao; | FLOAT *ao; | ||||
| jj = offset; | jj = offset; | ||||
| int js = 0; | |||||
| #ifdef DOUBLE | #ifdef DOUBLE | ||||
| svint64_t index = svindex_s64(0LL, lda); | |||||
| int64_t js = 0; | |||||
| svbool_t pn = svwhilelt_b64(js, n); | svbool_t pn = svwhilelt_b64(js, n); | ||||
| int n_active = svcntp_b64(svptrue_b64(), pn); | int n_active = svcntp_b64(svptrue_b64(), pn); | ||||
| #else | #else | ||||
| svint32_t index = svindex_s32(0, lda); | |||||
| int32_t js = 0; | |||||
| svbool_t pn = svwhilelt_b32(js, n); | svbool_t pn = svwhilelt_b32(js, n); | ||||
| int n_active = svcntp_b32(svptrue_b32(), pn); | int n_active = svcntp_b32(svptrue_b32(), pn); | ||||
| #endif | #endif | ||||
| @@ -73,26 +72,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT | |||||
| if (ii == jj) { | if (ii == jj) { | ||||
| for (int j = 0; j < n_active; j++) { | for (int j = 0; j < n_active; j++) { | ||||
| *(b + j * n_active + j) = INV(*(a + j * lda + j)); | |||||
| *(b + j * n_active + j) = INV(*(ao + j * lda + j)); | |||||
| for (int k = j+1; k < n_active; k++) { | for (int k = j+1; k < n_active; k++) { | ||||
| *(b + j * n_active + k) = *(a + j * lda + k); | |||||
| *(b + j * n_active + k) = *(ao + j * lda + k); | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| if (ii < jj) { | |||||
| for (int j = 0; j < n_active; j++) { | |||||
| b += n_active * n_active; | |||||
| ao += lda * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } else { | |||||
| if (ii < jj) { | |||||
| svfloat64_t aj_vec = svld1(pn, ao); | svfloat64_t aj_vec = svld1(pn, ao); | ||||
| svst1(pn, b, aj_vec); | svst1(pn, b, aj_vec); | ||||
| ao += lda; | |||||
| } | } | ||||
| ao += lda; | |||||
| b += n_active; | |||||
| i ++; | |||||
| ii ++; | |||||
| } | } | ||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } while (i < m); | } while (i < m); | ||||
| @@ -48,17 +48,18 @@ | |||||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | ||||
| BLASLONG i, ii, j, jj; | |||||
| BLASLONG i, ii, jj; | |||||
| FLOAT *ao; | FLOAT *ao; | ||||
| jj = offset; | jj = offset; | ||||
| int js = 0; | |||||
| #ifdef DOUBLE | #ifdef DOUBLE | ||||
| int64_t js = 0; | |||||
| svint64_t index = svindex_s64(0LL, lda); | svint64_t index = svindex_s64(0LL, lda); | ||||
| svbool_t pn = svwhilelt_b64(js, n); | svbool_t pn = svwhilelt_b64(js, n); | ||||
| int n_active = svcntp_b64(svptrue_b64(), pn); | int n_active = svcntp_b64(svptrue_b64(), pn); | ||||
| #else | #else | ||||
| int32_t js = 0; | |||||
| svint32_t index = svindex_s32(0, lda); | svint32_t index = svindex_s32(0, lda); | ||||
| svbool_t pn = svwhilelt_b32(js, n); | svbool_t pn = svwhilelt_b32(js, n); | ||||
| int n_active = svcntp_b32(svptrue_b32(), pn); | int n_active = svcntp_b32(svptrue_b32(), pn); | ||||
| @@ -73,25 +74,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT | |||||
| if (ii == jj) { | if (ii == jj) { | ||||
| for (int j = 0; j < n_active; j++) { | for (int j = 0; j < n_active; j++) { | ||||
| *(b + j * n_active + j) = INV(*(a + j * lda + j)); | |||||
| *(b + j * n_active + j) = INV(*(ao + j * lda + j)); | |||||
| for (int k = j+1; k < n_active; k++) { | for (int k = j+1; k < n_active; k++) { | ||||
| *(b + j * n_active + k) = *(a + k * lda + j); | |||||
| *(b + j * n_active + k) = *(ao + k * lda + j); | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| if (ii < jj) { | |||||
| for (int j = 0; j < n_active; j++) { | |||||
| ao += n_active; | |||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } else { | |||||
| if (ii < jj) { | |||||
| svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); | svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); | ||||
| svst1(pn, b, aj_vec); | svst1(pn, b, aj_vec); | ||||
| ao++; | |||||
| } | } | ||||
| ao++; | |||||
| b += n_active; | |||||
| i++; | |||||
| ii++; | |||||
| } | } | ||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } while (i < m); | } while (i < m); | ||||
| @@ -48,18 +48,17 @@ | |||||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ | ||||
| BLASLONG i, ii, j, jj; | |||||
| BLASLONG i, ii, jj; | |||||
| FLOAT *ao; | FLOAT *ao; | ||||
| jj = offset; | jj = offset; | ||||
| int js = 0; | |||||
| #ifdef DOUBLE | #ifdef DOUBLE | ||||
| svint64_t index = svindex_s64(0LL, lda); | |||||
| int64_t js = 0; | |||||
| svbool_t pn = svwhilelt_b64(js, n); | svbool_t pn = svwhilelt_b64(js, n); | ||||
| int n_active = svcntp_b64(svptrue_b64(), pn); | int n_active = svcntp_b64(svptrue_b64(), pn); | ||||
| #else | #else | ||||
| svint32_t index = svindex_s32(0, lda); | |||||
| int32_t js = 0; | |||||
| svbool_t pn = svwhilelt_b32(js, n); | svbool_t pn = svwhilelt_b32(js, n); | ||||
| int n_active = svcntp_b32(svptrue_b32(), pn); | int n_active = svcntp_b32(svptrue_b32(), pn); | ||||
| #endif | #endif | ||||
| @@ -74,25 +73,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT | |||||
| if (ii == jj) { | if (ii == jj) { | ||||
| for (int j = 0; j < n_active; j++) { | for (int j = 0; j < n_active; j++) { | ||||
| for (int k = 0; k < j; k++) { | for (int k = 0; k < j; k++) { | ||||
| *(b + j * n_active + k) = *(a + j * lda + k); | |||||
| *(b + j * n_active + k) = *(ao + j * lda + k); | |||||
| } | } | ||||
| *(b + j * n_active + j) = INV(*(a + j * lda + j)); | |||||
| *(b + j * n_active + j) = INV(*(ao + j * lda + j)); | |||||
| } | } | ||||
| } | |||||
| if (ii > jj) { | |||||
| for (int j = 0; j < n_active; j++) { | |||||
| ao += lda * n_active; | |||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| } else { | |||||
| if (ii > jj) { | |||||
| svfloat64_t aj_vec = svld1(pn, ao); | svfloat64_t aj_vec = svld1(pn, ao); | ||||
| svst1(pn, b, aj_vec); | svst1(pn, b, aj_vec); | ||||
| ao += lda; | |||||
| } | } | ||||
| } | |||||
| b += n_active * n_active; | |||||
| i += n_active; | |||||
| ii += n_active; | |||||
| ao += lda; | |||||
| b += n_active; | |||||
| i ++; | |||||
| ii ++; | |||||
| } | |||||
| } while (i < m); | } while (i < m); | ||||