From c37509c213a34a8cae449ededd7bc7064675ecc4 Mon Sep 17 00:00:00 2001 From: "tingbo.liao" Date: Tue, 31 Dec 2024 08:46:55 +0800 Subject: [PATCH] Optimize the nrm2_rvv function to further improve performance. Signed-off-by: tingbo.liao --- kernel/riscv64/nrm2_rvv.c | 370 +++++++++++++++++++++----------------- 1 file changed, 204 insertions(+), 166 deletions(-) diff --git a/kernel/riscv64/nrm2_rvv.c b/kernel/riscv64/nrm2_rvv.c index 14ed68b0a..472b1148e 100644 --- a/kernel/riscv64/nrm2_rvv.c +++ b/kernel/riscv64/nrm2_rvv.c @@ -27,185 +27,223 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -#if defined(DOUBLE) -#define VSETVL __riscv_vsetvl_e64m4 -#define FLOAT_V_T vfloat64m4_t -#define FLOAT_V_T_M1 vfloat64m1_t -#define VLEV_FLOAT __riscv_vle64_v_f64m4 -#define VLSEV_FLOAT __riscv_vlse64_v_f64m4 -#define VFMVVF_FLOAT __riscv_vfmv_v_f_f64m4 -#define VFMVSF_FLOAT __riscv_vfmv_s_f_f64m4 -#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f64m1 -#define MASK_T vbool16_t -#define VFABS __riscv_vfabs_v_f64m4 -#define VMFNE __riscv_vmfne_vf_f64m4_b16 -#define VMFGT __riscv_vmfgt_vv_f64m4_b16 -#define VMFEQ __riscv_vmfeq_vf_f64m4_b16 -#define VCPOP __riscv_vcpop_m_b16 -#define VFREDMAX __riscv_vfredmax_vs_f64m4_f64m1 -#define VFREDMIN __riscv_vfredmin_vs_f64m4_f64m1 -#define VFIRST __riscv_vfirst_m_b16 -#define VRGATHER __riscv_vrgather_vx_f64m4 -#define VFDIV __riscv_vfdiv_vv_f64m4 -#define VFDIV_M __riscv_vfdiv_vv_f64m4_mu -#define VFMUL __riscv_vfmul_vv_f64m4 -#define VFMUL_M __riscv_vfmul_vv_f64m4_mu -#define VFMACC __riscv_vfmacc_vv_f64m4 -#define VFMACC_M __riscv_vfmacc_vv_f64m4_mu -#define VMSBF __riscv_vmsbf_m_b16 -#define VMSOF __riscv_vmsof_m_b16 -#define VMAND __riscv_vmand_mm_b16 -#define VMANDN __riscv_vmand_mm_b16 -#define VFREDSUM __riscv_vfredusum_vs_f64m4_f64m1 -#define VMERGE __riscv_vmerge_vvm_f64m4 -#define VSEV_FLOAT __riscv_vse64_v_f64m4 -#define EXTRACT_FLOAT0_V(v) __riscv_vfmv_f_s_f64m4_f64(v) -#define ABS fabs -#else -#define VSETVL __riscv_vsetvl_e32m4 +#if !defined(DOUBLE) +#define VSETVL(n) __riscv_vsetvl_e32m4(n) +#define VSETVL_MAX __riscv_vsetvlmax_e32m4() #define FLOAT_V_T vfloat32m4_t #define FLOAT_V_T_M1 vfloat32m1_t +#define MASK_T vbool8_t #define VLEV_FLOAT __riscv_vle32_v_f32m4 #define VLSEV_FLOAT __riscv_vlse32_v_f32m4 +#define VFREDSUM_FLOAT __riscv_vfredusum_vs_f32m4_f32m1_tu +#define VFMACCVV_FLOAT_TU __riscv_vfmacc_vv_f32m4_tu #define VFMVVF_FLOAT __riscv_vfmv_v_f_f32m4 -#define VFMVSF_FLOAT __riscv_vfmv_s_f_f32m4 #define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f32m1 -#define MASK_T vbool8_t -#define VFABS __riscv_vfabs_v_f32m4 -#define VMFNE __riscv_vmfne_vf_f32m4_b8 -#define VMFGT __riscv_vmfgt_vv_f32m4_b8 -#define VMFEQ __riscv_vmfeq_vf_f32m4_b8 -#define VCPOP __riscv_vcpop_m_b8 -#define VFREDMAX __riscv_vfredmax_vs_f32m4_f32m1 -#define VFREDMIN __riscv_vfredmin_vs_f32m4_f32m1 -#define VFIRST __riscv_vfirst_m_b8 -#define VRGATHER __riscv_vrgather_vx_f32m4 -#define VFDIV __riscv_vfdiv_vv_f32m4 -#define VFDIV_M __riscv_vfdiv_vv_f32m4_mu -#define VFMUL __riscv_vfmul_vv_f32m4 -#define VFMUL_M __riscv_vfmul_vv_f32m4_mu -#define VFMACC __riscv_vfmacc_vv_f32m4 -#define VFMACC_M __riscv_vfmacc_vv_f32m4_mu -#define VMSBF __riscv_vmsbf_m_b8 -#define VMSOF __riscv_vmsof_m_b8 -#define VMAND __riscv_vmand_mm_b8 -#define VMANDN __riscv_vmand_mm_b8 -#define VFREDSUM __riscv_vfredusum_vs_f32m4_f32m1 -#define VMERGE __riscv_vmerge_vvm_f32m4 -#define VSEV_FLOAT __riscv_vse32_v_f32m4 -#define EXTRACT_FLOAT0_V(v) __riscv_vfmv_f_s_f32m4_f32(v) +#define VMFIRSTM __riscv_vfirst_m_b8 +#define VFREDMAXVS_FLOAT_TU __riscv_vfredmax_vs_f32m4_f32m1_tu +#define VFMVFS_FLOAT __riscv_vfmv_f_s_f32m1_f32 +#define VMFGTVF_FLOAT __riscv_vmfgt_vf_f32m4_b8 +#define VFDIVVF_FLOAT __riscv_vfdiv_vf_f32m4 +#define VFABSV_FLOAT __riscv_vfabs_v_f32m4 #define ABS fabsf +#else +#define VSETVL(n) __riscv_vsetvl_e64m4(n) +#define VSETVL_MAX __riscv_vsetvlmax_e64m4() +#define FLOAT_V_T vfloat64m4_t +#define FLOAT_V_T_M1 vfloat64m1_t +#define MASK_T vbool16_t +#define VLEV_FLOAT __riscv_vle64_v_f64m4 +#define VLSEV_FLOAT __riscv_vlse64_v_f64m4 +#define VFREDSUM_FLOAT __riscv_vfredusum_vs_f64m4_f64m1_tu +#define VFMACCVV_FLOAT_TU __riscv_vfmacc_vv_f64m4_tu +#define VFMVVF_FLOAT __riscv_vfmv_v_f_f64m4 +#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f64m1 +#define VMFIRSTM __riscv_vfirst_m_b16 +#define VFREDMAXVS_FLOAT_TU __riscv_vfredmax_vs_f64m4_f64m1_tu +#define VFMVFS_FLOAT __riscv_vfmv_f_s_f64m1_f64 +#define VMFGTVF_FLOAT __riscv_vmfgt_vf_f64m4_b16 +#define VFDIVVF_FLOAT __riscv_vfdiv_vf_f64m4 +#define VFABSV_FLOAT __riscv_vfabs_v_f64m4 +#define ABS fabs #endif FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x) { - BLASLONG i=0; - - if (n <= 0 || inc_x == 0) return(0.0); - if(n == 1) return (ABS(x[0])); - - unsigned int gvl = 0; - - MASK_T nonzero_mask; - MASK_T scale_mask; - - gvl = VSETVL(n); - FLOAT_V_T v0; - FLOAT_V_T v_ssq = VFMVVF_FLOAT(0, gvl); - FLOAT_V_T v_scale = VFMVVF_FLOAT(0, gvl); - - FLOAT scale = 0; - FLOAT ssq = 0; - unsigned int stride_x = inc_x * sizeof(FLOAT); - int idx = 0; - - if( n >= gvl && inc_x > 0 ) // don't pay overheads if we're not doing useful work - { - for(i=0; i 0 ){ + FLOAT_V_T vr, v0, v_zero; + unsigned int gvl = 0; + FLOAT_V_T_M1 v_res, v_z0; + gvl = VSETVL_MAX; + v_res = VFMVVF_FLOAT_M1(0, gvl); + v_z0 = VFMVVF_FLOAT_M1(0, gvl); + MASK_T mask; + BLASLONG index = 0; + + if (inc_x == 1) { + gvl = VSETVL(n); + vr = VFMVVF_FLOAT(0, gvl); + v_zero = VFMVVF_FLOAT(0, gvl); + for (i = 0, j = 0; i < n / gvl; i++) { + v0 = VLEV_FLOAT(&x[j], gvl); + // fabs(vector) + v0 = VFABSV_FLOAT(v0, gvl); + // if scale change + mask = VMFGTVF_FLOAT(v0, scale, gvl); + index = VMFIRSTM(mask, gvl); + if (index == -1) { // no elements greater than scale + if (scale != 0.0) { + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl); + } + } + else { // found greater element + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq before current vector + ssq += VFMVFS_FLOAT(v_res); + // find max + v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl); + // update ssq before max_index + ssq = ssq * (scale / VFMVFS_FLOAT(v_res)) * (scale / VFMVFS_FLOAT(v_res)); + // update scale + scale = VFMVFS_FLOAT(v_res); + // ssq in vector vr + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl); + } + j += gvl; + } + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq now + ssq += VFMVFS_FLOAT(v_res); + + // tail processing + if(j < n){ + gvl = VSETVL(n-j); + v0 = VLEV_FLOAT(&x[j], gvl); + // fabs(vector) + v0 = VFABSV_FLOAT(v0, gvl); + // if scale change + mask = VMFGTVF_FLOAT(v0, scale, gvl); + index = VMFIRSTM(mask, gvl); + if (index == -1) { // no elements greater than scale + if(scale != 0.0) + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + } else { // found greater element + // find max + v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl); + // update ssq before max_index + ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res)); + // update scale + scale = VFMVFS_FLOAT(v_res); + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + } + vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl); + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq now + ssq += VFMVFS_FLOAT(v_res); + } + } + else { + gvl = VSETVL(n); + vr = VFMVVF_FLOAT(0, gvl); + v_zero = VFMVVF_FLOAT(0, gvl); + unsigned int stride_x = inc_x * sizeof(FLOAT); + int idx = 0, inc_v = inc_x * gvl; + for (i = 0, j = 0; i < n / gvl; i++) { + v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl); + // fabs(vector) + v0 = VFABSV_FLOAT(v0, gvl); + // if scale change + mask = VMFGTVF_FLOAT(v0, scale, gvl); + index = VMFIRSTM(mask, gvl); + if (index == -1) {// no elements greater than scale + if(scale != 0.0){ + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl); + } + } + else { // found greater element + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq before current vector + ssq += VFMVFS_FLOAT(v_res); + // find max + v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl); + // update ssq before max_index + ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res)); + // update scale + scale = VFMVFS_FLOAT(v_res); + // ssq in vector vr + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl); + } + j += gvl; + idx += inc_v; + } + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq now + ssq += VFMVFS_FLOAT(v_res); + + // tail processing + if (j < n) { + gvl = VSETVL(n-j); + v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl); + // fabs(vector) + v0 = VFABSV_FLOAT(v0, gvl); + // if scale change + mask = VMFGTVF_FLOAT(v0, scale, gvl); + index = VMFIRSTM(mask, gvl); + if(index == -1) { // no elements greater than scale + if(scale != 0.0) { + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl); + } + } + else { // found greater element + // find max + v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl); + // update ssq before max_index + ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res)); + // update scale + scale = VFMVFS_FLOAT(v_res); + v0 = VFDIVVF_FLOAT(v0, scale, gvl); + vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl); + } + // ssq in vector vr: vr[0] + v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl); + // total ssq now + ssq += VFMVFS_FLOAT(v_res); + } + } + } + else{ + // using scalar ops when inc_x < 0 + n *= inc_x; while(abs(i) < abs(n)){ - if ( x[i] != 0.0 ){ - FLOAT absxi = ABS( x[i] ); - if ( scale < absxi ){ - ssq = 1 + ssq * ( scale / absxi ) * ( scale / absxi ); - scale = absxi ; - } - else{ - ssq += ( absxi/scale ) * ( absxi/scale ); - } - - } - - i += inc_x; + if ( x[i] != 0.0 ){ + FLOAT absxi = ABS( x[i] ); + if ( scale < absxi ){ + ssq = 1 + ssq * ( scale / absxi ) * ( scale / absxi ); + scale = absxi ; + } + else{ + ssq += ( absxi/scale ) * ( absxi/scale ); + } + + } + i += inc_x; } - + } return(scale * sqrt(ssq)); }