Resurrect and complete cblas_?gemm_batchtags/v0.3.28^2
| @@ -133,7 +133,7 @@ jobs: | |||
| mkdir build | |||
| cd build | |||
| call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" | |||
| cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON .. | |||
| cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang-new -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON .. | |||
| cmake --build . --config Release | |||
| ctest | |||
| @@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint | |||
| void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, | |||
| double *c, OPENBLAS_CONST blasint cldc); | |||
| void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| /*** BFLOAT16 and INT8 extensions ***/ | |||
| /* convert float array to BFLOAT16 array by rounding */ | |||
| void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); | |||
| @@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum | |||
| void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, | |||
| OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); | |||
| void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif /* __cplusplus */ | |||
| @@ -1937,8 +1937,13 @@ int zimatcopy_k_rtc(BLASLONG, BLASLONG, double, double, double *, BLASLONG); | |||
| int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG); | |||
| int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG); | |||
| int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); | |||
| int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); | |||
| int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); | |||
| int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| #ifdef __CUDACC__ | |||
| } | |||
| @@ -2655,9 +2655,20 @@ typedef struct { | |||
| BLASLONG prea, preb, prec, pred; | |||
| #endif | |||
| //for gemm_batch | |||
| void * routine; | |||
| int routine_mode; | |||
| } blas_arg_t; | |||
| #endif | |||
| #ifdef SMALL_MATRIX_OPT | |||
| #define BLAS_SMALL_OPT 0x10000U | |||
| #define BLAS_SMALL_B0_OPT 0x30000U | |||
| #endif | |||
| #ifdef XDOUBLE | |||
| #define TRSV_NUU qtrsv_NUU | |||
| @@ -68,6 +68,8 @@ if (USE_THREAD) | |||
| endif () | |||
| foreach (float_type ${FLOAT_TYPES}) | |||
| GenerateNamedObjects("gemm_batch_thread.c" "" "gemm_batch_thread" 0 "" "" false ${float_type}) | |||
| if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") | |||
| GenerateCombinationObjects("zherk_kernel.c" "LOWER;CONJ" "U;N" "HERK" 2 "herk_kernel" false ${float_type}) | |||
| # TRANS needs to be set/unset when CONJ is set/unset, so can't use it as a combination | |||
| @@ -37,7 +37,7 @@ SBLASOBJS += \ | |||
| ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \ | |||
| ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \ | |||
| ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \ | |||
| ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) | |||
| ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX) | |||
| DBLASOBJS += \ | |||
| dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \ | |||
| @@ -53,7 +53,7 @@ DBLASOBJS += \ | |||
| dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \ | |||
| dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \ | |||
| dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \ | |||
| dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) | |||
| dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX) | |||
| QBLASOBJS += \ | |||
| qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \ | |||
| @@ -103,7 +103,7 @@ CBLASOBJS += \ | |||
| cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \ | |||
| csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \ | |||
| cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \ | |||
| cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) | |||
| cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX) | |||
| ZBLASOBJS += \ | |||
| zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \ | |||
| @@ -137,7 +137,7 @@ ZBLASOBJS += \ | |||
| zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \ | |||
| zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \ | |||
| zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \ | |||
| zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) | |||
| zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX) | |||
| XBLASOBJS += \ | |||
| @@ -2942,6 +2942,21 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h | |||
| beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h | |||
| $(CC) -c $(PFLAGS) $< -o $(@F) | |||
| sbgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h | |||
| $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | |||
| @@ -0,0 +1,156 @@ | |||
| /***************************************************************************** | |||
| Copyright (c) 2020, The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written | |||
| permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| **********************************************************************************/ | |||
| #include "common.h" | |||
| void openblas_warning(int verbose, const char * msg); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ | |||
| int routine_mode; | |||
| #ifndef COMPLEX | |||
| int (*gemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG); | |||
| int (*gemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); | |||
| #else | |||
| int (*zgemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG); | |||
| int (*zgemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); | |||
| FLOAT alpha[2], beta[2]; | |||
| #endif | |||
| routine_mode=args->routine_mode; | |||
| if((routine_mode & BLAS_SMALL_B0_OPT) == BLAS_SMALL_B0_OPT){ | |||
| #ifndef COMPLEX | |||
| gemm_small_kernel_b0=args->routine; | |||
| gemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, args->c, args->ldc); | |||
| #else | |||
| zgemm_small_kernel_b0=args->routine; | |||
| alpha[0] = *((FLOAT *)args -> alpha + 0); | |||
| alpha[1] = *((FLOAT *)args -> alpha + 1); | |||
| zgemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, args->c, args->ldc); | |||
| #endif | |||
| return(0); | |||
| }else if(routine_mode & BLAS_SMALL_OPT){ | |||
| #ifndef COMPLEX | |||
| gemm_small_kernel=args->routine; | |||
| gemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, *(FLOAT *)(args->beta), args->c, args->ldc); | |||
| #else | |||
| zgemm_small_kernel=args->routine; | |||
| alpha[0] = *((FLOAT *)args -> alpha + 0); | |||
| alpha[1] = *((FLOAT *)args -> alpha + 1); | |||
| beta[0] = *((FLOAT *)args -> beta + 0); | |||
| beta[1] = *((FLOAT *)args -> beta + 1); | |||
| zgemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, beta[0], beta[1], args->c, args->ldc); | |||
| #endif | |||
| return(0); | |||
| } | |||
| return(1); | |||
| } | |||
| #endif | |||
| int CNAME(blas_arg_t * args_array, BLASLONG nums){ | |||
| XFLOAT *buffer; | |||
| XFLOAT *sa, *sb; | |||
| int nthreads=1; | |||
| int (*routine)(blas_arg_t *, void *, void *, XFLOAT *, XFLOAT *, BLASLONG); | |||
| int i=0, /*j,*/ current_nums; | |||
| #ifdef SMP | |||
| blas_queue_t * queue=NULL; | |||
| #endif | |||
| if(nums <=0 ) return 0; | |||
| buffer = (XFLOAT *)blas_memory_alloc(0); | |||
| sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); | |||
| sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
| #ifdef SMP | |||
| nthreads=num_cpu_avail(3); | |||
| if(nthreads==1){ | |||
| #endif | |||
| //single thread | |||
| for(i=0; i<nums; i++){ | |||
| routine=args_array[i].routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if(args_array[i].routine_mode & BLAS_SMALL_OPT){ | |||
| inner_small_matrix_thread(&args_array[i], NULL, NULL, NULL, NULL, 0); | |||
| }else{ | |||
| #endif | |||
| routine(&args_array[i], NULL, NULL, sa, sb, 0); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| } | |||
| #endif | |||
| } | |||
| #ifdef SMP | |||
| } else { | |||
| //multi thread | |||
| queue=(blas_queue_t *)malloc((nums+1) * sizeof(blas_queue_t)); | |||
| if(queue == NULL){ | |||
| openblas_warning(0, "memory alloc failed!\n"); | |||
| return(1); | |||
| } | |||
| for(i=0; i<nums; i++){ | |||
| queue[i].args=&args_array[i]; | |||
| queue[i].range_m=NULL; | |||
| queue[i].range_n=NULL; | |||
| queue[i].sa=NULL; | |||
| queue[i].sb=NULL; | |||
| queue[i].next=&queue[i+1]; | |||
| queue[i].mode=args_array[i].routine_mode; | |||
| queue[i].routine=args_array[i].routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if((args_array[i].routine_mode & BLAS_SMALL_B0_OPT) || (args_array[i].routine_mode & BLAS_SMALL_OPT)){ | |||
| queue[i].routine=inner_small_matrix_thread; | |||
| } | |||
| #endif | |||
| } | |||
| for(i=0; i<nums; i+=nthreads){ | |||
| current_nums=((nums-i)>nthreads)? nthreads: (nums-i); | |||
| queue[i].sa=sa; | |||
| queue[i].sb=sb; | |||
| queue[i+current_nums-1].next=NULL; | |||
| exec_blas(current_nums, &queue[i]); | |||
| } | |||
| free(queue); | |||
| } | |||
| #endif | |||
| blas_memory_free(buffer); | |||
| return 0; | |||
| } | |||
| @@ -60,7 +60,7 @@ cblasobjsc=" | |||
| cblas_ctbsv cblas_ctpmv cblas_ctpsv cblas_ctrmm cblas_ctrmv cblas_ctrsm cblas_ctrsv | |||
| cblas_scnrm2 cblas_scasum cblas_cgemmt | |||
| cblas_icamax cblas_icamin cblas_icmin cblas_icmax cblas_scsum cblas_cimatcopy cblas_comatcopy | |||
| cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin | |||
| cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin cblas_cgemm_batch | |||
| " | |||
| cblasobjsd=" | |||
| cblas_dasum cblas_daxpy cblas_dcopy cblas_ddot | |||
| @@ -70,7 +70,7 @@ cblasobjsd=" | |||
| cblas_dsyr2k cblas_dsyr cblas_dsyrk cblas_dtbmv cblas_dtbsv cblas_dtpmv cblas_dtpsv | |||
| cblas_dtrmm cblas_dtrmv cblas_dtrsm cblas_dtrsv cblas_daxpby cblas_dgeadd cblas_dgemmt | |||
| cblas_idamax cblas_idamin cblas_idmin cblas_idmax cblas_dsum cblas_dimatcopy cblas_domatcopy | |||
| cblas_damax cblas_damin | |||
| cblas_damax cblas_damin cblas_dgemm_batch | |||
| " | |||
| cblasobjss=" | |||
| @@ -82,7 +82,7 @@ cblasobjss=" | |||
| cblas_stbmv cblas_stbsv cblas_stpmv cblas_stpsv cblas_strmm cblas_strmv cblas_strsm | |||
| cblas_strsv cblas_sgeadd cblas_sgemmt | |||
| cblas_isamax cblas_isamin cblas_ismin cblas_ismax cblas_ssum cblas_simatcopy cblas_somatcopy | |||
| cblas_samax cblas_samin | |||
| cblas_samax cblas_samin cblas_sgemm_batch | |||
| " | |||
| cblasobjsz=" | |||
| @@ -94,12 +94,12 @@ cblasobjsz=" | |||
| cblas_ztrsv cblas_cdotc_sub cblas_cdotu_sub cblas_zdotc_sub cblas_zdotu_sub | |||
| cblas_zaxpby cblas_zgeadd cblas_zgemmt | |||
| cblas_izamax cblas_izamin cblas_izmin cblas_izmax cblas_dzsum cblas_zimatcopy cblas_zomatcopy | |||
| cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin | |||
| cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin cblas_zgemm_batch | |||
| " | |||
| cblasobjs="cblas_xerbla" | |||
| bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod" | |||
| bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch" | |||
| exblasobjs=" | |||
| qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm | |||
| @@ -97,6 +97,9 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS}) | |||
| #sdsdot, dsdot | |||
| if (BUILD_SINGLE OR BUILD_DOUBLE) | |||
| GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") | |||
| if(CBLAS_FLAG EQUAL 1) | |||
| GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false) | |||
| endif () | |||
| endif () | |||
| if (BUILD_DOUBLE) | |||
| GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE") | |||
| @@ -125,6 +128,9 @@ if (BUILD_BFLOAT16) | |||
| GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") | |||
| GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16") | |||
| GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16") | |||
| if(CBLAS_FLAG EQUAL 1) | |||
| GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") | |||
| endif () | |||
| endif () | |||
| # complex-specific sources | |||
| @@ -154,6 +160,9 @@ foreach (float_type ${FLOAT_TYPES}) | |||
| GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX") | |||
| GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX") | |||
| GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX") | |||
| if(CBLAS_FLAG EQUAL 1) | |||
| GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX") | |||
| endif () | |||
| endif () | |||
| if (${float_type} STREQUAL "ZCOMPLEX") | |||
| GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX") | |||
| @@ -163,6 +172,9 @@ foreach (float_type ${FLOAT_TYPES}) | |||
| GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") | |||
| GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") | |||
| GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") | |||
| if(CBLAS_FLAG EQUAL 1) | |||
| GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") | |||
| endif () | |||
| endif () | |||
| endforeach () | |||
| @@ -212,6 +224,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE) | |||
| GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE") | |||
| GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE") | |||
| GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE") | |||
| GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE") | |||
| GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE") | |||
| GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE") | |||
| GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE") | |||
| @@ -225,6 +238,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) | |||
| GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE") | |||
| GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE") | |||
| @@ -282,12 +282,12 @@ CSBLAS2OBJS = \ | |||
| CSBLAS3OBJS = \ | |||
| cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ | |||
| cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ | |||
| cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) | |||
| cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) | |||
| CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) | |||
| CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) | |||
| CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) | |||
| CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | |||
| endif | |||
| @@ -308,7 +308,7 @@ CDBLAS2OBJS = \ | |||
| CDBLAS3OBJS += \ | |||
| cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ | |||
| cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ | |||
| cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) | |||
| cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) | |||
| CCBLAS1OBJS = \ | |||
| cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ | |||
| @@ -333,7 +333,7 @@ CCBLAS3OBJS = \ | |||
| cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ | |||
| cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ | |||
| cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ | |||
| cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) | |||
| cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) | |||
| CXERBLAOBJ = \ | |||
| cblas_xerbla.$(SUFFIX) | |||
| @@ -364,7 +364,7 @@ CZBLAS3OBJS = \ | |||
| cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ | |||
| cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ | |||
| cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ | |||
| cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) | |||
| cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) | |||
| ifeq ($(SUPPORT_GEMM3M), 1) | |||
| @@ -2419,3 +2419,17 @@ cblas_zgeadd.$(SUFFIX) cblas_zgeadd.$(PSUFFIX) : zgeadd.c | |||
| cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_sbgemm_batch.$(SUFFIX) cblas_sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| @@ -0,0 +1,372 @@ | |||
| /***************************************************************************** | |||
| Copyright (c) 2020, The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written | |||
| permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| **********************************************************************************/ | |||
| #include <stdio.h> | |||
| #include <stdlib.h> | |||
| #include "common.h" | |||
| void openblas_warning(int verbose, const char * msg); | |||
| #ifndef COMPLEX | |||
| #ifdef XDOUBLE | |||
| #define ERROR_NAME "QGEMM_BATCH " | |||
| #elif defined(DOUBLE) | |||
| #define ERROR_NAME "DGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD dgemm_batch_thread | |||
| #else | |||
| #define ERROR_NAME "SGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD sgemm_batch_thread | |||
| #endif | |||
| #else | |||
| #ifdef XDOUBLE | |||
| #define ERROR_NAME "XGEMM_BATCH " | |||
| #elif defined(DOUBLE) | |||
| #define ERROR_NAME "ZGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD zgemm_batch_thread | |||
| #else | |||
| #define ERROR_NAME "CGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD cgemm_batch_thread | |||
| #endif | |||
| #endif | |||
| static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { | |||
| GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, | |||
| GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, | |||
| GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, | |||
| GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, | |||
| }; | |||
| #if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) | |||
| #define USE_SMALL_MATRIX_OPT 1 | |||
| #else | |||
| #define USE_SMALL_MATRIX_OPT 0 | |||
| #endif | |||
| #if USE_SMALL_MATRIX_OPT | |||
| #ifndef DYNAMIC_ARCH | |||
| #define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) | |||
| #else | |||
| #define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) | |||
| #endif | |||
| #ifndef COMPLEX | |||
| static size_t gemm_small_kernel[] = { | |||
| GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, | |||
| GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, | |||
| }; | |||
| static size_t gemm_small_kernel_b0[] = { | |||
| GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, | |||
| GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, | |||
| }; | |||
| #define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) | |||
| #define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) | |||
| #else | |||
| static size_t zgemm_small_kernel[] = { | |||
| GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, | |||
| GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, | |||
| GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, | |||
| GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, | |||
| }; | |||
| static size_t zgemm_small_kernel_b0[] = { | |||
| GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, | |||
| GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, | |||
| GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, | |||
| GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, | |||
| }; | |||
| #define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) | |||
| #define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) | |||
| #endif | |||
| #endif | |||
| void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array, | |||
| blasint * m_array, blasint * n_array, blasint * k_array, | |||
| #ifndef COMPLEX | |||
| FLOAT * alpha_array, | |||
| FLOAT ** a_array, blasint * lda_array, | |||
| FLOAT ** b_array, blasint * ldb_array, | |||
| FLOAT * beta_array, | |||
| FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) { | |||
| #else | |||
| void * valpha_array, | |||
| void ** va_array, blasint * lda_array, | |||
| void ** vb_array, blasint * ldb_array, | |||
| void * vbeta_array, | |||
| void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) { | |||
| FLOAT * alpha_array=(FLOAT *)valpha_array; | |||
| FLOAT * beta_array=(FLOAT *)vbeta_array; | |||
| FLOAT ** a_array=(FLOAT**)va_array; | |||
| FLOAT ** b_array=(FLOAT**)vb_array; | |||
| FLOAT ** c_array=(FLOAT**)vc_array; | |||
| #endif | |||
| blas_arg_t * args_array=NULL; | |||
| int mode=0, group_mode=0; | |||
| blasint total_num=0; | |||
| blasint i=0, j=0, matrix_idx=0, count=0; | |||
| int group_transa, group_transb; | |||
| BLASLONG group_nrowa, group_nrowb; | |||
| blasint info; | |||
| void * group_alpha, * group_beta; | |||
| BLASLONG group_m, group_n, group_k; | |||
| BLASLONG group_lda, group_ldb, group_ldc; | |||
| void * group_routine=NULL; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| void * group_small_matrix_opt_routine=NULL; | |||
| #endif | |||
| #if defined (SMP) || defined(SMALL_MATRIX_OPT) | |||
| double MNK; | |||
| #endif | |||
| PRINT_DEBUG_CNAME; | |||
| for(i=0; i<group_count; i++){ | |||
| total_num+=group_size[i]; | |||
| } | |||
| args_array=(blas_arg_t *)malloc(total_num * sizeof(blas_arg_t)); | |||
| if(args_array == NULL){ | |||
| openblas_warning(0, "memory alloc failed!\n"); | |||
| return; | |||
| } | |||
| #ifdef SMP | |||
| #ifndef COMPLEX | |||
| #ifdef XDOUBLE | |||
| mode = BLAS_XDOUBLE | BLAS_REAL; | |||
| #elif defined(DOUBLE) | |||
| mode = BLAS_DOUBLE | BLAS_REAL; | |||
| #else | |||
| mode = BLAS_SINGLE | BLAS_REAL; | |||
| #endif | |||
| #else | |||
| #ifdef XDOUBLE | |||
| mode = BLAS_XDOUBLE | BLAS_COMPLEX; | |||
| #elif defined(DOUBLE) | |||
| mode = BLAS_DOUBLE | BLAS_COMPLEX; | |||
| #else | |||
| mode = BLAS_SINGLE | BLAS_COMPLEX; | |||
| #endif | |||
| #endif | |||
| #endif | |||
| for(i=0; i<group_count; matrix_idx+=group_size[i], i++){ | |||
| group_alpha = (void *)&alpha_array[i * COMPSIZE]; | |||
| group_beta = (void *)&beta_array[i * COMPSIZE]; | |||
| group_m = group_n = group_k = 0; | |||
| group_lda = group_ldb = group_ldc = 0; | |||
| group_transa = -1; | |||
| group_transb = -1; | |||
| info = 0; | |||
| if (order == CblasColMajor) { | |||
| group_m = m_array[i]; | |||
| group_n = n_array[i]; | |||
| group_k = k_array[i]; | |||
| group_lda = lda_array[i]; | |||
| group_ldb = ldb_array[i]; | |||
| group_ldc = ldc_array[i]; | |||
| if (transa_array[i] == CblasNoTrans) group_transa = 0; | |||
| if (transa_array[i] == CblasTrans) group_transa = 1; | |||
| #ifndef COMPLEX | |||
| if (transa_array[i] == CblasConjNoTrans) group_transa = 0; | |||
| if (transa_array[i] == CblasConjTrans) group_transa = 1; | |||
| #else | |||
| if (transa_array[i] == CblasConjNoTrans) group_transa = 2; | |||
| if (transa_array[i] == CblasConjTrans) group_transa = 3; | |||
| #endif | |||
| if (transb_array[i] == CblasNoTrans) group_transb = 0; | |||
| if (transb_array[i] == CblasTrans) group_transb = 1; | |||
| #ifndef COMPLEX | |||
| if (transb_array[i] == CblasConjNoTrans) group_transb = 0; | |||
| if (transb_array[i] == CblasConjTrans) group_transb = 1; | |||
| #else | |||
| if (transb_array[i] == CblasConjNoTrans) group_transb = 2; | |||
| if (transb_array[i] == CblasConjTrans) group_transb = 3; | |||
| #endif | |||
| group_nrowa = group_m; | |||
| if (group_transa & 1) group_nrowa = group_k; | |||
| group_nrowb = group_k; | |||
| if (group_transb & 1) group_nrowb = group_n; | |||
| info=-1; | |||
| if (group_ldc < group_m) info = 13; | |||
| if (group_ldb < group_nrowb) info = 10; | |||
| if (group_lda < group_nrowa) info = 8; | |||
| if (group_k < 0) info = 5; | |||
| if (group_n < 0) info = 4; | |||
| if (group_m < 0) info = 3; | |||
| if (group_transb < 0) info = 2; | |||
| if (group_transa < 0) info = 1; | |||
| }else if (order == CblasRowMajor) { | |||
| group_m = n_array[i]; | |||
| group_n = m_array[i]; | |||
| group_k = k_array[i]; | |||
| group_lda = ldb_array[i]; | |||
| group_ldb = lda_array[i]; | |||
| group_ldc = ldc_array[i]; | |||
| if (transb_array[i] == CblasNoTrans) group_transa = 0; | |||
| if (transb_array[i] == CblasTrans) group_transa = 1; | |||
| #ifndef COMPLEX | |||
| if (transb_array[i] == CblasConjNoTrans) group_transa = 0; | |||
| if (transb_array[i] == CblasConjTrans) group_transa = 1; | |||
| #else | |||
| if (transb_array[i] == CblasConjNoTrans) group_transa = 2; | |||
| if (transb_array[i] == CblasConjTrans) group_transa = 3; | |||
| #endif | |||
| if (transa_array[i] == CblasNoTrans) group_transb = 0; | |||
| if (transa_array[i] == CblasTrans) group_transb = 1; | |||
| #ifndef COMPLEX | |||
| if (transa_array[i] == CblasConjNoTrans) group_transb = 0; | |||
| if (transa_array[i] == CblasConjTrans) group_transb = 1; | |||
| #else | |||
| if (transa_array[i] == CblasConjNoTrans) group_transb = 2; | |||
| if (transa_array[i] == CblasConjTrans) group_transb = 3; | |||
| #endif | |||
| group_nrowa = group_m; | |||
| if (group_transa & 1) group_nrowa = group_k; | |||
| group_nrowb = group_k; | |||
| if (group_transb & 1) group_nrowb = group_n; | |||
| info=-1; | |||
| if (group_ldc < group_m) info = 13; | |||
| if (group_ldb < group_nrowb) info = 10; | |||
| if (group_lda < group_nrowa) info = 8; | |||
| if (group_k < 0) info = 5; | |||
| if (group_n < 0) info = 4; | |||
| if (group_m < 0) info = 3; | |||
| if (group_transb < 0) info = 2; | |||
| if (group_transa < 0) info = 1; | |||
| } | |||
| if (info >= 0) { | |||
| BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); | |||
| free(args_array); | |||
| return; | |||
| } | |||
| if (group_m == 0 || group_n == 0) continue; | |||
| group_mode=mode; | |||
| #if defined(SMP) || defined(SMALL_MATRIX_OPT) | |||
| MNK = (double) group_m * (double) group_n * (double) group_k; | |||
| #endif | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if (MNK <= 100.0*100.0*100.0){ | |||
| group_routine=NULL; | |||
| #if !defined(COMPLEX) | |||
| if(*(FLOAT *)(group_beta) == 0.0){ | |||
| group_mode=mode | BLAS_SMALL_B0_OPT; | |||
| group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]); | |||
| }else{ | |||
| group_mode=mode | BLAS_SMALL_OPT; | |||
| group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]); | |||
| } | |||
| #else | |||
| if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){ | |||
| group_mode=mode | BLAS_SMALL_B0_OPT; | |||
| group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]); | |||
| }else{ | |||
| group_mode=mode | BLAS_SMALL_OPT; | |||
| group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]); | |||
| } | |||
| #endif | |||
| }else{ | |||
| #endif | |||
| group_routine=(void*)(gemm[(group_transb<<2)|group_transa]); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| } | |||
| #endif | |||
| for(j=0; j<group_size[i]; j++){ | |||
| args_array[count].m=group_m; | |||
| args_array[count].n=group_n; | |||
| args_array[count].k=group_k; | |||
| args_array[count].lda=group_lda; | |||
| args_array[count].ldb=group_ldb; | |||
| args_array[count].ldc=group_ldc; | |||
| args_array[count].alpha=group_alpha; | |||
| args_array[count].beta=group_beta; | |||
| if (order == CblasColMajor) { | |||
| args_array[count].a=(a_array[matrix_idx+j]); | |||
| args_array[count].b=(b_array[matrix_idx+j]); | |||
| }else if(order == CblasRowMajor){ | |||
| args_array[count].a=(b_array[matrix_idx+j]); | |||
| args_array[count].b=(a_array[matrix_idx+j]); | |||
| } | |||
| args_array[count].c=(c_array[matrix_idx+j]); | |||
| args_array[count].routine_mode=group_mode; | |||
| args_array[count].routine=group_routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if (!group_routine) | |||
| args_array[count].routine=group_small_matrix_opt_routine; | |||
| #endif | |||
| count++; | |||
| } | |||
| } | |||
| if(count>0){ | |||
| GEMM_BATCH_THREAD(args_array,count); | |||
| } | |||
| free(args_array); | |||
| } | |||