Browse Source

Merge pull request #5037 from tingboliao/develop

Optimize the nrm2_rvv function to further improve performance.
tags/v0.3.29
Martin Kroeker GitHub 1 year ago
parent
commit
a63282a688
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed file with 204 additions and 166 deletions
  1. +204
    -166
      kernel/riscv64/nrm2_rvv.c

+ 204
- 166
kernel/riscv64/nrm2_rvv.c View File

@@ -27,185 +27,223 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "common.h"

/*
 * Precision-specific aliases for the RVV intrinsics used by the kernel below.
 * Exactly one branch is active: single precision (e32, LMUL=4) unless DOUBLE
 * is defined, in which case the e64 variants are selected.
 */
#if !defined(DOUBLE)
#define VSETVL(n)            __riscv_vsetvl_e32m4(n)
#define VSETVL_MAX           __riscv_vsetvlmax_e32m4()
#define FLOAT_V_T            vfloat32m4_t
#define FLOAT_V_T_M1         vfloat32m1_t
#define MASK_T               vbool8_t
#define VLEV_FLOAT           __riscv_vle32_v_f32m4
#define VLSEV_FLOAT          __riscv_vlse32_v_f32m4
#define VFREDSUM_FLOAT       __riscv_vfredusum_vs_f32m4_f32m1_tu
#define VFMACCVV_FLOAT_TU    __riscv_vfmacc_vv_f32m4_tu
#define VFMVVF_FLOAT         __riscv_vfmv_v_f_f32m4
#define VFMVVF_FLOAT_M1      __riscv_vfmv_v_f_f32m1
#define VMFIRSTM             __riscv_vfirst_m_b8
#define VFREDMAXVS_FLOAT_TU  __riscv_vfredmax_vs_f32m4_f32m1_tu
#define VFMVFS_FLOAT         __riscv_vfmv_f_s_f32m1_f32
#define VMFGTVF_FLOAT        __riscv_vmfgt_vf_f32m4_b8
#define VFDIVVF_FLOAT        __riscv_vfdiv_vf_f32m4
#define VFABSV_FLOAT         __riscv_vfabs_v_f32m4
#define ABS                  fabsf
#else
#define VSETVL(n)            __riscv_vsetvl_e64m4(n)
#define VSETVL_MAX           __riscv_vsetvlmax_e64m4()
#define FLOAT_V_T            vfloat64m4_t
#define FLOAT_V_T_M1         vfloat64m1_t
#define MASK_T               vbool16_t
#define VLEV_FLOAT           __riscv_vle64_v_f64m4
#define VLSEV_FLOAT          __riscv_vlse64_v_f64m4
#define VFREDSUM_FLOAT       __riscv_vfredusum_vs_f64m4_f64m1_tu
#define VFMACCVV_FLOAT_TU    __riscv_vfmacc_vv_f64m4_tu
#define VFMVVF_FLOAT         __riscv_vfmv_v_f_f64m4
#define VFMVVF_FLOAT_M1      __riscv_vfmv_v_f_f64m1
#define VMFIRSTM             __riscv_vfirst_m_b16
#define VFREDMAXVS_FLOAT_TU  __riscv_vfredmax_vs_f64m4_f64m1_tu
#define VFMVFS_FLOAT         __riscv_vfmv_f_s_f64m1_f64
#define VMFGTVF_FLOAT        __riscv_vmfgt_vf_f64m4_b16
#define VFDIVVF_FLOAT        __riscv_vfdiv_vf_f64m4
#define VFABSV_FLOAT         __riscv_vfabs_v_f64m4
#define ABS                  fabs
#endif

/*
 * Scaled Euclidean norm of a strided vector.
 *
 * The result is returned as scale * sqrt(ssq), where `scale` is the largest
 * |x[k]| encountered so far and `ssq` accumulates (|x[k]| / scale)^2.  Every
 * time a larger magnitude appears, the already accumulated ssq is rescaled by
 * (old_scale / new_scale)^2 before `scale` is replaced.
 *
 * n      number of elements to visit
 * x      first element; element k is read from x[k * inc_x]
 * inc_x  element stride
 *
 * Returns 0.0 when n <= 0 or inc_x == 0, and ABS(x[0]) when n == 1.
 *
 * inc_x > 0 is handled with vector code (contiguous loads for inc_x == 1,
 * strided loads otherwise); inc_x < 0 is handled with a scalar loop.
 */
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    if (n <= 0 || inc_x == 0) return(0.0);
    if (n == 1) return(ABS(x[0]));

    BLASLONG i = 0, j = 0;
    FLOAT scale = 0.0, ssq = 0.0;

    if (inc_x > 0) {
        FLOAT_V_T vr, v0, v_zero;     /* partial ssq lanes, current chunk, zeros */
        FLOAT_V_T_M1 v_res, v_z0;     /* reduction destination and zero seed     */
        MASK_T mask;
        BLASLONG index = 0;
        unsigned int gvl = VSETVL_MAX;

        v_res = VFMVVF_FLOAT_M1(0, gvl);
        v_z0 = VFMVVF_FLOAT_M1(0, gvl);

        if (inc_x == 1) {
            gvl = VSETVL(n);
            vr = VFMVVF_FLOAT(0, gvl);
            v_zero = VFMVVF_FLOAT(0, gvl);

            /* full chunks of gvl contiguous elements */
            for (i = 0, j = 0; i < n / gvl; i++) {
                v0 = VLEV_FLOAT(&x[j], gvl);
                v0 = VFABSV_FLOAT(v0, gvl);

                /* does any lane exceed the current scale? */
                mask = VMFGTVF_FLOAT(v0, scale, gvl);
                index = VMFIRSTM(mask, gvl);
                if (index == -1) {
                    /* scale unchanged; nothing to add while scale is still 0 */
                    if (scale != 0.0) {
                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                        vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl);
                    }
                } else {
                    /* fold the lanes accumulated under the old scale into ssq */
                    v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
                    ssq += VFMVFS_FLOAT(v_res);

                    /* new scale is the largest magnitude in this chunk */
                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
                    FLOAT vmax = VFMVFS_FLOAT(v_res);
                    ssq = ssq * (scale / vmax) * (scale / vmax);
                    scale = vmax;

                    /* restart the lane accumulator with this chunk only */
                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
                }
                j += gvl;
            }
            /* fold the remaining lane accumulator into ssq */
            v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
            ssq += VFMVFS_FLOAT(v_res);

            /* tail: fewer than gvl elements left */
            if (j < n) {
                gvl = VSETVL(n - j);
                v0 = VLEV_FLOAT(&x[j], gvl);
                v0 = VFABSV_FLOAT(v0, gvl);

                mask = VMFGTVF_FLOAT(v0, scale, gvl);
                index = VMFIRSTM(mask, gvl);
                if (index == -1) {
                    if (scale != 0.0)
                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                } else {
                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
                    FLOAT vmax = VFMVFS_FLOAT(v_res);
                    ssq = ssq * (scale / vmax) * (scale / vmax);
                    scale = vmax;
                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                }
                vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
                v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
                ssq += VFMVFS_FLOAT(v_res);
            }
        } else {
            gvl = VSETVL(n);
            vr = VFMVVF_FLOAT(0, gvl);
            v_zero = VFMVVF_FLOAT(0, gvl);

            /* kept in BLASLONG so large inc_x * gvl products are not truncated */
            BLASLONG stride_x = inc_x * sizeof(FLOAT);   /* byte stride for the load */
            BLASLONG idx = 0, inc_v = inc_x * gvl;       /* element offset per chunk */

            /* full chunks of gvl strided elements */
            for (i = 0, j = 0; i < n / gvl; i++) {
                v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl);
                v0 = VFABSV_FLOAT(v0, gvl);

                mask = VMFGTVF_FLOAT(v0, scale, gvl);
                index = VMFIRSTM(mask, gvl);
                if (index == -1) {
                    if (scale != 0.0) {
                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                        vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl);
                    }
                } else {
                    v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
                    ssq += VFMVFS_FLOAT(v_res);

                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
                    FLOAT vmax = VFMVFS_FLOAT(v_res);
                    ssq = ssq * (scale / vmax) * (scale / vmax);
                    scale = vmax;

                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
                }
                j += gvl;
                idx += inc_v;
            }
            v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
            ssq += VFMVFS_FLOAT(v_res);

            /* tail: fewer than gvl elements left */
            if (j < n) {
                gvl = VSETVL(n - j);
                v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl);
                v0 = VFABSV_FLOAT(v0, gvl);

                mask = VMFGTVF_FLOAT(v0, scale, gvl);
                index = VMFIRSTM(mask, gvl);
                if (index == -1) {
                    /* vr is only rebuilt when scale is non-zero; otherwise the
                       reduction below re-reads vr as left by the main loop */
                    if (scale != 0.0) {
                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                        vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
                    }
                } else {
                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
                    FLOAT vmax = VFMVFS_FLOAT(v_res);
                    ssq = ssq * (scale / vmax) * (scale / vmax);
                    scale = vmax;
                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
                }
                v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
                ssq += VFMVFS_FLOAT(v_res);
            }
        }
    } else {
        /*
         * Scalar path for inc_x < 0: visits x[0], x[inc_x], x[2*inc_x], ...
         * Exactly n elements are processed, one per iteration.  The count is
         * kept in BLASLONG rather than comparing abs(i) with abs(n * inc_x),
         * because abs() operates on int.
         */
        for (j = 0; j < n; j++, i += inc_x) {
            if (x[i] != 0.0) {
                FLOAT absxi = ABS(x[i]);
                if (scale < absxi) {
                    ssq = 1 + ssq * (scale / absxi) * (scale / absxi);
                    scale = absxi;
                } else {
                    ssq += (absxi / scale) * (absxi / scale);
                }
            }
        }
    }

    return(scale * sqrt(ssq));
}



Loading…
Cancel
Save