Browse Source

openblas_threads_callback mechanism for changing threads backend

pull/2255/head
Steven G. Johnson 6 years ago
parent
commit
5daf8c9d62
6 changed files with 72 additions and 37 deletions
  1. +30
    -25
      cblas.h
  2. +10
    -4
      common_interface.h
  3. +1
    -0
      driver/others/CMakeLists.txt
  4. +1
    -1
      driver/others/Makefile
  5. +13
    -0
      driver/others/blas_server_callback.c
  6. +17
    -7
      driver/others/blas_server_omp.c

+ 30
- 25
cblas.h View File

@@ -25,6 +25,11 @@ char* openblas_get_config(void);
/*Get the CPU corename on runtime.*/
char* openblas_get_corename(void);

/*Set the threading backend to a custom callback.*/
typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, void *dojob_data);
typedef void (*openblas_threads_callback)(void *callback_data, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, void *dojob_data);
void openblas_set_threads_callback(openblas_threads_callback callback, void *callback_data);

/* Get the parallelization type which is used by OpenBLAS */
int openblas_get_parallel(void);
/* OpenBLAS is compiled for sequential use */
@@ -52,7 +57,7 @@ typedef enum CBLAS_UPLO {CblasUpper=121, CblasLower=122} CBLAS_UPLO;
typedef enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132} CBLAS_DIAG;
typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
typedef CBLAS_ORDER CBLAS_LAYOUT;
float cblas_sdsdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST float *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float *y, OPENBLAS_CONST blasint incy);
double cblas_dsdot (OPENBLAS_CONST blasint n, OPENBLAS_CONST float *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float *y, OPENBLAS_CONST blasint incy);
float cblas_sdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float *y, OPENBLAS_CONST blasint incy);
@@ -350,32 +355,32 @@ void cblas_caxpby(OPENBLAS_CONST blasint n, OPENBLAS_CONST void *alpha, OPENBLAS

void cblas_zaxpby(OPENBLAS_CONST blasint n, OPENBLAS_CONST void *alpha, OPENBLAS_CONST void *x, OPENBLAS_CONST blasint incx,OPENBLAS_CONST void *beta, void *y, OPENBLAS_CONST blasint incy);

void cblas_somatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, OPENBLAS_CONST float *a,
OPENBLAS_CONST blasint clda, float *b, OPENBLAS_CONST blasint cldb);
void cblas_somatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, OPENBLAS_CONST float *a,
OPENBLAS_CONST blasint clda, float *b, OPENBLAS_CONST blasint cldb);
void cblas_domatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double calpha, OPENBLAS_CONST double *a,
OPENBLAS_CONST blasint clda, double *b, OPENBLAS_CONST blasint cldb);
void cblas_comatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float* calpha, OPENBLAS_CONST float* a,
OPENBLAS_CONST blasint clda, float*b, OPENBLAS_CONST blasint cldb);
void cblas_zomatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double* calpha, OPENBLAS_CONST double* a,
OPENBLAS_CONST blasint clda, double *b, OPENBLAS_CONST blasint cldb);
void cblas_simatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, float *a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
OPENBLAS_CONST blasint clda, double *b, OPENBLAS_CONST blasint cldb);
void cblas_comatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float* calpha, OPENBLAS_CONST float* a,
OPENBLAS_CONST blasint clda, float*b, OPENBLAS_CONST blasint cldb);
void cblas_zomatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double* calpha, OPENBLAS_CONST double* a,
OPENBLAS_CONST blasint clda, double *b, OPENBLAS_CONST blasint cldb);
void cblas_simatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, float *a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_dimatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double calpha, double *a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_cimatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float* calpha, float* a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_zimatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double* calpha, double* a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_sgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, float *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST float cbeta,
float *c, OPENBLAS_CONST blasint cldc);
void cblas_dgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double cbeta,
double *c, OPENBLAS_CONST blasint cldc);
void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float *calpha, float *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST float *cbeta,
float *c, OPENBLAS_CONST blasint cldc);
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
double *c, OPENBLAS_CONST blasint cldc);
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_cimatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float* calpha, float* a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_zimatcopy(OPENBLAS_CONST enum CBLAS_ORDER CORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE CTRANS, OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double* calpha, double* a,
OPENBLAS_CONST blasint clda, OPENBLAS_CONST blasint cldb);
void cblas_sgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float calpha, float *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST float cbeta,
float *c, OPENBLAS_CONST blasint cldc);
void cblas_dgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double cbeta,
double *c, OPENBLAS_CONST blasint cldc);
void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST float *calpha, float *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST float *cbeta,
float *c, OPENBLAS_CONST blasint cldc);
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
double *c, OPENBLAS_CONST blasint cldc);


#ifdef __cplusplus


+ 10
- 4
common_interface.h View File

@@ -47,6 +47,12 @@ int BLASFUNC(xerbla)(char *, blasint *info, blasint);

void openblas_set_num_threads_(int *);

typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, void *dojob_data);
typedef void (*openblas_threads_callback)(void *callback_data, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, void *dojob_data);
void openblas_set_threads_callback(openblas_threads_callback callback, void *callback_data);
extern openblas_threads_callback openblas_threads_callback_;
extern void *openblas_threads_callback_data_;

FLOATRET BLASFUNC(sdot) (blasint *, float *, blasint *, float *, blasint *);
FLOATRET BLASFUNC(sdsdot)(blasint *, float *, float *, blasint *, float *, blasint *);

@@ -761,10 +767,10 @@ void BLASFUNC(dimatcopy) (char *, char *, blasint *, blasint *, double *, do
void BLASFUNC(cimatcopy) (char *, char *, blasint *, blasint *, float *, float *, blasint *, blasint *);
void BLASFUNC(zimatcopy) (char *, char *, blasint *, blasint *, double *, double *, blasint *, blasint *);

void BLASFUNC(sgeadd) (blasint *, blasint *, float *, float *, blasint *, float *, float *, blasint*);
void BLASFUNC(dgeadd) (blasint *, blasint *, double *, double *, blasint *, double *, double *, blasint*);
void BLASFUNC(cgeadd) (blasint *, blasint *, float *, float *, blasint *, float *, float *, blasint*);
void BLASFUNC(zgeadd) (blasint *, blasint *, double *, double *, blasint *, double *, double *, blasint*);
void BLASFUNC(sgeadd) (blasint *, blasint *, float *, float *, blasint *, float *, float *, blasint*);
void BLASFUNC(dgeadd) (blasint *, blasint *, double *, double *, blasint *, double *, double *, blasint*);
void BLASFUNC(cgeadd) (blasint *, blasint *, float *, float *, blasint *, float *, float *, blasint*);
void BLASFUNC(zgeadd) (blasint *, blasint *, double *, double *, blasint *, double *, double *, blasint*);


#ifdef __cplusplus


+ 1
- 0
driver/others/CMakeLists.txt View File

@@ -39,6 +39,7 @@ set(COMMON_SOURCES
openblas_env.c
openblas_get_num_procs.c
openblas_get_num_threads.c
blas_server_callback.c
)

# these need to have NAME/CNAME set, so use GenerateNamedObjects, but don't use standard name mangling


+ 1
- 1
driver/others/Makefile View File

@@ -1,7 +1,7 @@
TOPDIR = ../..
include ../../Makefile.system

COMMONOBJS = memory.$(SUFFIX) xerbla.$(SUFFIX) c_abs.$(SUFFIX) z_abs.$(SUFFIX) openblas_set_num_threads.$(SUFFIX) openblas_get_num_threads.$(SUFFIX) openblas_get_num_procs.$(SUFFIX) openblas_get_config.$(SUFFIX) openblas_get_parallel.$(SUFFIX) openblas_error_handle.$(SUFFIX) openblas_env.$(SUFFIX)
COMMONOBJS = memory.$(SUFFIX) xerbla.$(SUFFIX) c_abs.$(SUFFIX) z_abs.$(SUFFIX) openblas_set_num_threads.$(SUFFIX) openblas_get_num_threads.$(SUFFIX) openblas_get_num_procs.$(SUFFIX) openblas_get_config.$(SUFFIX) openblas_get_parallel.$(SUFFIX) openblas_error_handle.$(SUFFIX) openblas_env.$(SUFFIX) blas_server_callback.$(SUFFIX)

#COMMONOBJS += slamch.$(SUFFIX) slamc3.$(SUFFIX) dlamch.$(SUFFIX) dlamc3.$(SUFFIX)



+ 13
- 0
driver/others/blas_server_callback.c View File

@@ -0,0 +1,13 @@
#include "common.h"

/* global variable to change threading backend from openblas-managed to caller-managed */
openblas_threads_callback openblas_threads_callback_ = 0;
void *openblas_threads_callback_data_ = 0;

/* non-threadsafe function should be called before any other
openblas function to change how threads are managed */
void openblas_set_threads_callback(openblas_threads_callback callback, void *callback_data)
{
openblas_threads_callback_ = callback;
openblas_threads_callback_data_ = callback_data;
}

+ 17
- 7
driver/others/blas_server_omp.c View File

@@ -75,7 +75,8 @@ void goto_set_num_threads(int num_threads) {

blas_cpu_number = num_threads;

omp_set_num_threads(blas_cpu_number);
if (!openblas_threads_callback_)
omp_set_num_threads(blas_cpu_number);

//adjust buffer for each thread
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
@@ -222,10 +223,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
}
}

static void exec_threads(blas_queue_t *queue, int buf_index){

static void exec_threads(int thread_num, blas_queue_t *queue, int *buf_index){
void *buffer, *sa, *sb;
int pos=0, release_flag=0;
int release_flag=0;

buffer = NULL;
sa = queue -> sa;
@@ -238,8 +238,7 @@ static void exec_threads(blas_queue_t *queue, int buf_index){

if ((sa == NULL) && (sb == NULL) && ((queue -> mode & BLAS_PTHREAD) == 0)) {

pos = omp_get_thread_num();
buffer = blas_thread_buffer[buf_index][pos];
buffer = blas_thread_buffer[*buf_index][thread_num];

//fallback
if(buffer==NULL) {
@@ -335,14 +334,25 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
break;
}

if (openblas_threads_callback_) {
#ifndef USE_SIMPLE_THREADED_LEVEL3
for (i = 0; i < num; i ++)
queue[i].position = i;
#endif
openblas_threads_callback_(openblas_threads_callback_data_, (openblas_dojob_callback) exec_threads, num, sizeof(blas_queue_t), (void*) queue, (void*) &buf_index);
return;
}

#pragma omp parallel for schedule(OMP_SCHED)

#pragma omp parallel for schedule(static)
for (i = 0; i < num; i ++) {

#ifndef USE_SIMPLE_THREADED_LEVEL3
queue[i].position = i;
#endif

exec_threads(&queue[i], buf_index);
exec_threads(omp_get_thread_num(), &queue[i], &buf_index);
}

#if __STDC_VERSION__ >= 201112L


Loading…
Cancel
Save