Browse Source

Build symbol name from build system variables

pull/5423/head
Martin Kroeker GitHub 5 months ago
parent
commit
08a00326a4
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 17 deletions
  1. +18
    -6
      kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c
  2. +23
    -11
      kernel/arm64/sgemm_direct_arm64_sme1.c

+ 18
- 6
kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c View File

@@ -14,9 +14,17 @@
#include <arm_sme.h>
#endif
#if defined(DYNAMIC_ARCH)
#define COMBINE(a,b) a ## b
#define COMBINE2(a,b) COMBINE(a,b)
#define SME1_PREPROCESS_BASE sgemm_direct_sme1_preprocess
#define SME1_PREPROCESS COMBINE2(SME1_PREPROCESS_BASE,TS)
#else
#define SME1_PREPROCESS sgemm_direct_sme1_preprocess
#endif
/* Function prototypes */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
extern void SME1_PREPROCESS(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod);
/* Function Definitions */
static uint64_t sve_cntw() {
@@ -99,10 +107,11 @@ kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
}
return;
}
__arm_new("za") __arm_locally_streaming
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
static void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
const float *ba, const float *restrict bb, const float* beta,\
float *restrict C) {
@@ -125,6 +134,7 @@ void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, co
// Block over row dimension of C
for (; row_idx < num_rows; row_idx += row_batch) {
row_batch = MIN(row_batch, num_rows - row_idx);
uint64_t col_idx = 0;
uint64_t col_batch = 2*svl;
@@ -143,7 +153,7 @@ void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, co
#else
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
const float *ba, const float *restrict bb, const float* beta,\
float *restrict C){}
float *restrict C){fprintf(stderr,"empty sgemm_alpha_beta2x2 should never get called!!!\n");}
#endif
/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\
@@ -175,7 +185,8 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict
/* Pre-process the left matrix to make it suitable for
matrix sum of outer-product calculation
*/
sgemm_direct_sme1_preprocess(M, K, A, A_mod);
SME1_PREPROCESS(M, K, A, A_mod);
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
@@ -185,6 +196,7 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
/* Calculate C = alpha*A*B + beta*C */
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R);
free(A_mod);
@@ -194,6 +206,6 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
float beta, float * __restrict R, BLASLONG strideR){}
float beta, float * __restrict R, BLASLONG strideR){fprintf(stderr,"empty sgemm_direct_alpha_beta should not be called!!!\n");}
#endif

+ 23
- 11
kernel/arm64/sgemm_direct_arm64_sme1.c View File

@@ -8,17 +8,28 @@
#include <inttypes.h>
#include <math.h>
#if defined(HAVE_SME)
#if defined(DYNAMIC_ARCH)
#define COMBINE(a,b) a ## b
#define COMBINE2(a,b) COMBINE(a,b)
#define SME1_PREPROCESS_BASE sgemm_direct_sme1_preprocess
#define SME1_PREPROCESS COMBINE2(SME1_PREPROCESS_BASE,TS)
#define SME1_DIRECT2X2_BASE sgemm_direct_sme1_2VLx2VL
#define SME1_DIRECT2X2 COMBINE2(SME1_DIRECT2X2_BASE,TS)
#else
#define SME1_PREPROCESS sgemm_direct_sme1_preprocess
#define SME1_DIRECT2X2 sgemm_direct_sme1_2VLx2VL
#endif
/* Function prototypes */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
extern void sgemm_direct_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,\
extern void SME1_PREPROCESS(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) ;
extern void SME1_DIRECT2X2(uint64_t m, uint64_t k, uint64_t n,\
const float * matLeft,\
const float * restrict matRight,\
const float * restrict matResult) __asm__("sgemm_direct_sme1_2VLx2VL");
const float * restrict matResult) ;
/* Function Definitions */
uint64_t sve_cntw() {
static uint64_t sve_cntw() {
uint64_t cnt;
asm volatile(
"rdsvl %[res], #1\n"
@@ -39,7 +50,6 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
uint64_t m_mod, vl_elms;
vl_elms = sve_cntw();
m_mod = ceil((double)M/(double)vl_elms) * vl_elms;
float *A_mod = (float *) malloc(m_mod*K*sizeof(float));
@@ -57,10 +67,11 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
/* Pre-process the left matrix to make it suitable for
matrix sum of outer-product calculation
*/
sgemm_direct_sme1_preprocess(M, K, A, A_mod);
SME1_PREPROCESS(M, K, A, A_mod);
/* Calculate C = A*B */
sgemm_direct_sme1_2VLx2VL(M, K, N, A_mod, B, R);
fprintf(stderr,"sme direct calling 2x2\n");
SME1_DIRECT2X2(M, K, N, A_mod, B, R);
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
@@ -75,6 +86,7 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
float * __restrict R, BLASLONG strideR){}
float * __restrict R, BLASLONG strideR){
fprintf(stderr,"EMPTY sgemm_kernel_direct should never be called \n");
}
#endif

Loading…
Cancel
Save