Browse Source

Optimize the rotm kernel with RVV intrinsic.

Signed-off-by: tingbo.liao <tingbo.liao@starfivetech.com>
pull/5038/head
tingbo.liao 1 year ago
parent
commit
2afd7410e9
3 changed files with 249 additions and 0 deletions
  1. +1
    -0
      common_riscv64.h
  2. +176
    -0
      interface/rotm.c
  3. +72
    -0
      utest/test_rot.c

+ 1
- 0
common_riscv64.h View File

@@ -91,6 +91,7 @@ static inline int blas_quickdivide(blasint x, blasint y){


#if defined(C910V) || defined(RISCV64_ZVL256B) || defined(RISCV64_ZVL128B) || defined(x280) #if defined(C910V) || defined(RISCV64_ZVL256B) || defined(RISCV64_ZVL128B) || defined(x280)
# include <riscv_vector.h> # include <riscv_vector.h>
/* Set whenever <riscv_vector.h> is included above; code elsewhere tests
 * defined(RISCV_SIMD) to enable its RVV-intrinsic paths. */
#define RISCV_SIMD
#endif #endif


#if defined( __riscv_xtheadc ) && defined( __riscv_v ) && ( __riscv_v <= 7000 ) #if defined( __riscv_xtheadc ) && defined( __riscv_v ) && ( __riscv_v <= 7000 )


+ 176
- 0
interface/rotm.c View File

@@ -3,6 +3,26 @@
#include "functable.h" #include "functable.h"
#endif #endif


/*
 * Precision-neutral aliases for the RVV intrinsics used by the vectorized
 * rotm loops. All of them operate on LMUL=8 register groups (the "m8"
 * suffix); DOUBLE selects 64-bit elements, otherwise 32-bit elements.
 */
#if defined(RISCV_SIMD)
#if defined(DOUBLE)
#define VSETVL(n)      __riscv_vsetvl_e64m8(n)
#define FLOAT_V_T      vfloat64m8_t
#define VLSEV_FLOAT    __riscv_vlse64_v_f64m8   /* strided load           */
#define VSSEV_FLOAT    __riscv_vsse64_v_f64m8   /* strided store          */
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f64m8  /* acc + scalar * vector  */
#define VFMULVF_FLOAT  __riscv_vfmul_vf_f64m8   /* vector * scalar        */
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f64m8  /* scalar * vector - acc  */
#else
#define VSETVL(n)      __riscv_vsetvl_e32m8(n)
#define FLOAT_V_T      vfloat32m8_t
#define VLSEV_FLOAT    __riscv_vlse32_v_f32m8   /* strided load           */
#define VSSEV_FLOAT    __riscv_vsse32_v_f32m8   /* strided store          */
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f32m8  /* acc + scalar * vector  */
#define VFMULVF_FLOAT  __riscv_vfmul_vf_f32m8   /* vector * scalar        */
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f32m8  /* scalar * vector - acc  */
#endif
#endif

#ifndef CBLAS #ifndef CBLAS


void NAME(blasint *N, FLOAT *dx, blasint *INCX, FLOAT *dy, blasint *INCY, FLOAT *dparam){ void NAME(blasint *N, FLOAT *dx, blasint *INCX, FLOAT *dy, blasint *INCY, FLOAT *dparam){
@@ -25,6 +45,11 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
FLOAT dh11, dh12, dh22, dh21, dflag; FLOAT dh11, dh12, dh22, dh21, dflag;
blasint nsteps; blasint nsteps;


#if defined(RISCV_SIMD)
/* Vector temporaries: v_w / v_z__ hold the loaded x / y elements,
 * v_dx / v_dy hold the rotated values that are stored back. */
FLOAT_V_T v_w, v_z__, v_dx, v_dy;
/* stride, stride_x and stride_y are byte strides for the strided vector
 * loads/stores; offset is the pointer adjustment applied by the
 * equal-increment paths when the increment is negative. */
blasint stride, stride_x, stride_y, offset;
#endif

#ifndef CBLAS #ifndef CBLAS
PRINT_DEBUG_CNAME; PRINT_DEBUG_CNAME;
#else #else
@@ -53,6 +78,7 @@ L10:
dh21 = dparam[3]; dh21 = dparam[3];
i__1 = nsteps; i__1 = nsteps;
i__2 = incx; i__2 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
w = dx[i__]; w = dx[i__];
z__ = dy[i__]; z__ = dy[i__];
@@ -60,12 +86,36 @@ L10:
dy[i__] = w * dh21 + z__; dy[i__] = w * dh21 + z__;
/* L20: */ /* L20: */
} }
#else
if(i__2 < 0){
offset = i__1 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__2 * sizeof(FLOAT);
n = i__1 / i__2;
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);

VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140; goto L140;
L30: L30:
dh11 = dparam[2]; dh11 = dparam[2];
dh22 = dparam[5]; dh22 = dparam[5];
i__2 = nsteps; i__2 = nsteps;
i__1 = incx; i__1 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__1 < 0 ? i__ >= i__2 : i__ <= i__2; i__ += i__1) { for (i__ = 1; i__1 < 0 ? i__ >= i__2 : i__ <= i__2; i__ += i__1) {
w = dx[i__]; w = dx[i__];
z__ = dy[i__]; z__ = dy[i__];
@@ -73,6 +123,29 @@ L30:
dy[i__] = -w + dh22 * z__; dy[i__] = -w + dh22 * z__;
/* L40: */ /* L40: */
} }
#else
if(i__1 < 0){
offset = i__2 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__1 * sizeof(FLOAT);
n = i__2 / i__1;
for (size_t vl; n > 0; n -= vl, dx += vl*i__1, dy += vl*i__1) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);

VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140; goto L140;
L50: L50:
dh11 = dparam[2]; dh11 = dparam[2];
@@ -81,6 +154,7 @@ L50:
dh22 = dparam[5]; dh22 = dparam[5];
i__1 = nsteps; i__1 = nsteps;
i__2 = incx; i__2 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
w = dx[i__]; w = dx[i__];
z__ = dy[i__]; z__ = dy[i__];
@@ -88,6 +162,31 @@ L50:
dy[i__] = w * dh21 + z__ * dh22; dy[i__] = w * dh21 + z__ * dh22;
/* L60: */ /* L60: */
} }
#else
if(i__2 < 0){
offset = i__1 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__2 * sizeof(FLOAT);
n = i__1 / i__2;
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
VSSEV_FLOAT(&dx[1], stride, v_dx, vl);

v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140; goto L140;
L70: L70:
kx = 1; kx = 1;
@@ -110,6 +209,7 @@ L80:
dh12 = dparam[4]; dh12 = dparam[4];
dh21 = dparam[3]; dh21 = dparam[3];
i__2 = n; i__2 = n;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__ <= i__2; ++i__) { for (i__ = 1; i__ <= i__2; ++i__) {
w = dx[kx]; w = dx[kx];
z__ = dy[ky]; z__ = dy[ky];
@@ -119,11 +219,36 @@ L80:
ky += incy; ky += incy;
/* L90: */ /* L90: */
} }
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);

VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
goto L140; goto L140;
L100: L100:
dh11 = dparam[2]; dh11 = dparam[2];
dh22 = dparam[5]; dh22 = dparam[5];
i__2 = n; i__2 = n;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__ <= i__2; ++i__) { for (i__ = 1; i__ <= i__2; ++i__) {
w = dx[kx]; w = dx[kx];
z__ = dy[ky]; z__ = dy[ky];
@@ -133,8 +258,33 @@ L100:
ky += incy; ky += incy;
/* L110: */ /* L110: */
} }
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);

VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
goto L140; goto L140;
L120: L120:
#if !defined(RISCV_SIMD)
dh11 = dparam[2]; dh11 = dparam[2];
dh12 = dparam[4]; dh12 = dparam[4];
dh21 = dparam[3]; dh21 = dparam[3];
@@ -149,6 +299,32 @@ L120:
ky += incy; ky += incy;
/* L130: */ /* L130: */
} }
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);

v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
L140: L140:
return; return;
} }


+ 72
- 0
utest/test_rot.c View File

@@ -53,6 +53,42 @@ CTEST(rot,drot_inc_0)
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS);
} }
} }
/* drot with unit increments and c = s = 1: expects x + y in x and y - x in y. */
CTEST(rot,drot_inc_1)
{
    blasint n = 4, inc_x = 1, inc_y = 1;
    double cosine = 1.0, sine = 1.0;
    double x[] = {1.0, 3.0, 5.0, 7.0};
    double y[] = {2.0, 4.0, 6.0, 8.0};
    const double x_expected[] = {3.0, 7.0, 11.0, 15.0};
    const double y_expected[] = {1.0, 1.0, 1.0, 1.0};
    blasint k;

    /* Rotate in place with OpenBLAS. */
    BLASFUNC(drot)(&n, x, &inc_x, y, &inc_y, &cosine, &sine);

    for (k = 0; k < n; ++k) {
        ASSERT_DBL_NEAR_TOL(x_expected[k], x[k], DOUBLE_EPS);
        ASSERT_DBL_NEAR_TOL(y_expected[k], y[k], DOUBLE_EPS);
    }
}
/*
 * drotm with unit increments on 12 elements where x == y.
 * The expected vectors correspond to x' = 2*x + y and y' = -x + 5*y.
 */
CTEST(rot,drotm_inc_1)
{
    blasint n = 12, inc_x = 1, inc_y = 1;
    double param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
    double x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
    double y[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
    const double x_expected[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
    const double y_expected[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
    blasint k;

    /* Apply the modified rotation in place with OpenBLAS. */
    BLASFUNC(drotm)(&n, x, &inc_x, y, &inc_y, param);

    for (k = 0; k < n; ++k) {
        ASSERT_DBL_NEAR_TOL(x_expected[k], x[k], DOUBLE_EPS);
        ASSERT_DBL_NEAR_TOL(y_expected[k], y[k], DOUBLE_EPS);
    }
}
#endif #endif


#ifdef BUILD_COMPLEX16 #ifdef BUILD_COMPLEX16
@@ -96,6 +132,42 @@ CTEST(rot,srot_inc_0)
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS);
} }
} }
/* srot with unit increments and c = s = 1: expects x + y in x and y - x in y. */
CTEST(rot,srot_inc_1)
{
    blasint n = 4, inc_x = 1, inc_y = 1;
    float cosine = 1.0, sine = 1.0;
    float x[] = {1.0, 3.0, 5.0, 7.0};
    float y[] = {2.0, 4.0, 6.0, 8.0};
    const float x_expected[] = {3.0, 7.0, 11.0, 15.0};
    const float y_expected[] = {1.0, 1.0, 1.0, 1.0};
    blasint k;

    /* Rotate in place with OpenBLAS. */
    BLASFUNC(srot)(&n, x, &inc_x, y, &inc_y, &cosine, &sine);

    for (k = 0; k < n; ++k) {
        ASSERT_DBL_NEAR_TOL(x_expected[k], x[k], SINGLE_EPS);
        ASSERT_DBL_NEAR_TOL(y_expected[k], y[k], SINGLE_EPS);
    }
}
/*
 * srotm with unit increments on 12 elements where x == y.
 * The expected vectors correspond to x' = 2*x + y and y' = -x + 5*y.
 */
CTEST(rot,srotm_inc_1)
{
    blasint n = 12, inc_x = 1, inc_y = 1;
    float param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
    float x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
    float y[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
    const float x_expected[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
    const float y_expected[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
    blasint k;

    /* Apply the modified rotation in place with OpenBLAS. */
    BLASFUNC(srotm)(&n, x, &inc_x, y, &inc_y, param);

    for (k = 0; k < n; ++k) {
        ASSERT_DBL_NEAR_TOL(x_expected[k], x[k], SINGLE_EPS);
        ASSERT_DBL_NEAR_TOL(y_expected[k], y[k], SINGLE_EPS);
    }
}
#endif #endif


#ifdef BUILD_COMPLEX #ifdef BUILD_COMPLEX


Loading…
Cancel
Save