| @@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint | |||
| void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, | |||
| double *c, OPENBLAS_CONST blasint cldc); | |||
| void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
| OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -1919,6 +1919,10 @@ int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BL | |||
| int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); | |||
| int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); | |||
| int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
| #ifdef __CUDACC__ | |||
| } | |||
| @@ -2636,7 +2636,17 @@ typedef struct { | |||
| BLASLONG prea, preb, prec, pred; | |||
| #endif | |||
| //for gemm_batch | |||
| void * routine; | |||
| int routine_mode; | |||
| } blas_arg_t; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| #define BLAS_SMALL_OPT 0x10000U | |||
| #define BLAS_SMALL_B0_OPT 0x30000U | |||
| #endif | |||
| #endif | |||
| #ifdef XDOUBLE | |||
| @@ -37,7 +37,7 @@ SBLASOBJS += \ | |||
| ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \ | |||
| ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \ | |||
| ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \ | |||
| ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) | |||
| ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX) | |||
| DBLASOBJS += \ | |||
| dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \ | |||
| @@ -53,7 +53,7 @@ DBLASOBJS += \ | |||
| dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \ | |||
| dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \ | |||
| dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \ | |||
| dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) | |||
| dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX) | |||
| QBLASOBJS += \ | |||
| qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \ | |||
| @@ -103,7 +103,7 @@ CBLASOBJS += \ | |||
| cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \ | |||
| csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \ | |||
| cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \ | |||
| cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) | |||
| cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX) | |||
| ZBLASOBJS += \ | |||
| zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \ | |||
| @@ -137,7 +137,7 @@ ZBLASOBJS += \ | |||
| zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \ | |||
| zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \ | |||
| zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \ | |||
| zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) | |||
| zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX) | |||
| XBLASOBJS += \ | |||
| @@ -2888,6 +2888,18 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h | |||
| beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h | |||
| $(CC) -c $(PFLAGS) $< -o $(@F) | |||
| sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h | |||
| $(CC) -c $(CFLAGS) $< -o $(@F) | |||
| shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h | |||
| $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) | |||
| @@ -0,0 +1,151 @@ | |||
| /***************************************************************************** | |||
| Copyright (c) 2020, The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written | |||
| permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| **********************************************************************************/ | |||
| #include "common.h" | |||
| void openblas_warning(int verbose, const char * msg); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ | |||
| int routine_mode; | |||
| #ifndef COMPLEX | |||
| int (*gemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG); | |||
| int (*gemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); | |||
| #else | |||
| int (*zgemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG); | |||
| int (*zgemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); | |||
| FLOAT alpha[2], beta[2]; | |||
| #endif | |||
| routine_mode=args.routine_mode; | |||
| if(routine_mode & BLAS_SMALL_B0_OPT){ | |||
| #ifndef COMPLEX | |||
| gemm_small_kernel_b0=args.routine; | |||
| gemm_small_kernel_b0(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); | |||
| #else | |||
| zgemm_small_kernel_b0=args.routine; | |||
| alpha[0]=(FLOAT *)(args.alpha)[0]; | |||
| alpha[1]=(FLOAT *)(args.alpha)[1]; | |||
| zgemm_small_kernel_b0(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); | |||
| #endif | |||
| }else if(routine_mode & BLAS_SMALL_OPT){ | |||
| #ifndef COMPLEX | |||
| gemm_small_kernel=args.routine; | |||
| gemm_small_kernel(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); | |||
| #else | |||
| zgemm_small_kernel=args.routine; | |||
| alpha[0]=(FLOAT *)(args.alpha)[0]; | |||
| alpha[1]=(FLOAT *)(args.alpha)[1]; | |||
| beta[0]=(FLOAT *)(args.beta)[0]; | |||
| beta[1]=(FLOAT *)(args.beta)[1]; | |||
| zgemm_small_kernel(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); | |||
| #endif | |||
| } | |||
| } | |||
| #endif | |||
| int CNAME(blas_arg_t * args_array, BLASLONG nums){ | |||
| XFLOAT *buffer; | |||
| XFLOAT *sa, *sb; | |||
| int nthreads=1; | |||
| int (*routine)(blas_arg_t *, void *, void *, double *, double *, BLASLONG); | |||
| int i=0, j, current_nums; | |||
| #ifdef SMP | |||
| blas_queue_t * queue=NULL; | |||
| #endif | |||
| if(nums <=0 ) return 0; | |||
| buffer = (XFLOAT *)blas_memory_alloc(0); | |||
| sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); | |||
| sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
| #ifdef SMP | |||
| nthreads=num_cpu_avail(3); | |||
| if(nthreads==1){ | |||
| #endif | |||
| //single thread | |||
| for(i=0; i<nums; i++){ | |||
| routine=args_array[i].routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if(args_array[i].routine_mode & BLAS_SMALL_OPT){ | |||
| inner_small_matrix_thread(&args_array[i], NULL, NULL, NULL, NULL, 0); | |||
| }else{ | |||
| #endif | |||
| routine(&args_array[i], NULL, NULL, sa, sb, 0); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| } | |||
| #endif | |||
| } | |||
| #ifdef SMP | |||
| } else { | |||
| //multi thread | |||
| queue=(blas_queue_t *)malloc((nums+1) * sizeof(blas_queue_t)); | |||
| if(queue == NULL){ | |||
| openblas_warning(0, "memory alloc failed!\n"); | |||
| exit(1); | |||
| } | |||
| for(i=0; i<nums; i++){ | |||
| queue[i].args=&args_array[i]; | |||
| queue[i].range_m=NULL; | |||
| queue[i].range_n=NULL; | |||
| queue[i].sa=NULL; | |||
| queue[i].sb=NULL; | |||
| queue[i].next=&queue[i+1]; | |||
| queue[i].mode=args_array[i].routine_mode; | |||
| queue[i].routine=args_array[i].routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if(args_array[i].routine_mode & BLAS_SMALL_OPT){ | |||
| queue[i].routine=inner_small_matrix_thread; | |||
| } | |||
| #endif | |||
| } | |||
| for(i=0; i<nums; i+=nthreads){ | |||
| current_nums=((nums-i)>nthreads)? nthreads: (nums-i); | |||
| queue[i].sa=sa; | |||
| queue[i].sb=sb; | |||
| queue[i+current_nums-1].next=NULL; | |||
| exec_blas(current_nums, &queue[i]); | |||
| } | |||
| free(queue); | |||
| } | |||
| #endif | |||
| blas_memory_free(buffer); | |||
| return 0; | |||
| } | |||
| @@ -81,6 +81,7 @@ | |||
| cblas_ismin, cblas_idmin, cblas_icmin, cblas_izmin, | |||
| cblas_ismax, cblas_idmax, cblas_icmax, cblas_izmax, | |||
| cblas_ssum, cblas_dsum, cblas_scsum, cblas_dzsum, | |||
| cblas_sgemm_batch, cblas_dgemm_batch, cblas_cgemm_batch, cblas_zgemm_batch, | |||
| cblas_xerbla | |||
| ); | |||
| @@ -278,7 +278,7 @@ CSBLAS2OBJS = \ | |||
| CSBLAS3OBJS = \ | |||
| cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ | |||
| cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ | |||
| cblas_sgeadd.$(SUFFIX) | |||
| cblas_sgeadd.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) | |||
| ifeq ($(BUILD_HALF),1) | |||
| CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) | |||
| @@ -300,7 +300,7 @@ CDBLAS2OBJS = \ | |||
| CDBLAS3OBJS += \ | |||
| cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ | |||
| cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ | |||
| cblas_dgeadd.$(SUFFIX) | |||
| cblas_dgeadd.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) | |||
| CCBLAS1OBJS = \ | |||
| cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ | |||
| @@ -325,7 +325,7 @@ CCBLAS3OBJS = \ | |||
| cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ | |||
| cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ | |||
| cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ | |||
| cblas_cgeadd.$(SUFFIX) cblas_xerbla.$(SUFFIX) | |||
| cblas_cgeadd.$(SUFFIX) cblas_xerbla.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) | |||
| @@ -353,7 +353,7 @@ CZBLAS3OBJS = \ | |||
| cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ | |||
| cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ | |||
| cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ | |||
| cblas_zgeadd.$(SUFFIX) | |||
| cblas_zgeadd.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) | |||
| ifeq ($(SUPPORT_GEMM3M), 1) | |||
| @@ -2236,3 +2236,15 @@ cblas_zgeadd.$(SUFFIX) cblas_zgeadd.$(PSUFFIX) : zgeadd.c | |||
| cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h | |||
| $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) | |||
| @@ -0,0 +1,358 @@ | |||
| /***************************************************************************** | |||
| Copyright (c) 2020, The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written | |||
| permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| **********************************************************************************/ | |||
| #include <stdio.h> | |||
| #include <stdlib.h> | |||
| #include "common.h" | |||
| void openblas_warning(int verbose, const char * msg); | |||
| #ifndef COMPLEX | |||
| #ifdef XDOUBLE | |||
| #define ERROR_NAME "QGEMM_BATCH " | |||
| #elif defined(DOUBLE) | |||
| #define ERROR_NAME "DGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD dgemm_batch_thread | |||
| #else | |||
| #define ERROR_NAME "SGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD sgemm_batch_thread | |||
| #endif | |||
| #else | |||
| #ifdef XDOUBLE | |||
| #define ERROR_NAME "XGEMM_BATCH " | |||
| #elif defined(DOUBLE) | |||
| #define ERROR_NAME "ZGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD zgemm_batch_thread | |||
| #else | |||
| #define ERROR_NAME "CGEMM_BATCH " | |||
| #define GEMM_BATCH_THREAD cgemm_batch_thread | |||
| #endif | |||
| #endif | |||
| static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { | |||
| GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, | |||
| GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, | |||
| GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, | |||
| GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, | |||
| }; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| #ifndef COMPLEX | |||
| static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { | |||
| #ifndef GEMM3M | |||
| GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL, | |||
| GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL, | |||
| #endif | |||
| }; | |||
| static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { | |||
| #ifndef GEMM3M | |||
| GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL, | |||
| GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL, | |||
| #endif | |||
| }; | |||
| #else | |||
| static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG) = { | |||
| #ifndef GEMM3M | |||
| GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, | |||
| GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, | |||
| GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, | |||
| GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, | |||
| #endif | |||
| }; | |||
| static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { | |||
| #ifndef GEMM3M | |||
| GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, | |||
| GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, | |||
| GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, | |||
| GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, | |||
| #endif | |||
| }; | |||
| #endif | |||
| #endif | |||
| void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array, | |||
| blasint * m_array, blasint * n_array, blasint * k_array, | |||
| #ifndef COMPLEX | |||
| FLOAT * alpha_array, | |||
| FLOAT ** a_array, blasint * lda_array, | |||
| FLOAT ** b_array, blasint * ldb_array, | |||
| FLOAT * beta_array, | |||
| FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) { | |||
| #else | |||
| void * valpha_array, | |||
| void ** va_array, blasint * lda_array, | |||
| void ** vb_array, blasint * ldb_array, | |||
| void * vbeta_array, | |||
| void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) { | |||
| FLOAT * alpha_array=(FLOAT *)valpha_array; | |||
| FLOAT * beta_array=(FLOAT *)vbeta_array; | |||
| FLOAT ** a_array=(FLOAT**)va_array; | |||
| FLOAT ** b_array=(FLOAT**)vb_array; | |||
| FLOAT ** c_array=(FLOAT**)vc_array; | |||
| #endif | |||
| blas_arg_t * args_array=NULL; | |||
| int mode=0, group_mode=0; | |||
| blasint total_num=0; | |||
| blasint i=0, j=0, matrix_idx=0, count=0; | |||
| int group_transa, group_transb; | |||
| BLASLONG group_nrowa, group_nrowb; | |||
| blasint info; | |||
| void * group_alpha, * group_beta; | |||
| BLASLONG group_m, group_n, group_k; | |||
| BLASLONG group_lda, group_ldb, group_ldc; | |||
| void * group_routine=NULL; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| void * group_small_matrix_opt_routine=NULL; | |||
| #endif | |||
| #if defined (SMP) || defined(SMALL_MATRIX_OPT) | |||
| double MNK; | |||
| #endif | |||
| PRINT_DEBUG_CNAME; | |||
| for(i=0; i<group_count; i++){ | |||
| total_num+=group_size[i]; | |||
| } | |||
| args_array=(blas_arg_t *)malloc(total_num * sizeof(blas_arg_t)); | |||
| if(args_array == NULL){ | |||
| openblas_warning(0, "memory alloc failed!\n"); | |||
| exit(1); | |||
| } | |||
| #ifdef SMP | |||
| #ifndef COMPLEX | |||
| #ifdef XDOUBLE | |||
| mode = BLAS_XDOUBLE | BLAS_REAL; | |||
| #elif defined(DOUBLE) | |||
| mode = BLAS_DOUBLE | BLAS_REAL; | |||
| #else | |||
| mode = BLAS_SINGLE | BLAS_REAL; | |||
| #endif | |||
| #else | |||
| #ifdef XDOUBLE | |||
| mode = BLAS_XDOUBLE | BLAS_COMPLEX; | |||
| #elif defined(DOUBLE) | |||
| mode = BLAS_DOUBLE | BLAS_COMPLEX; | |||
| #else | |||
| mode = BLAS_SINGLE | BLAS_COMPLEX; | |||
| #endif | |||
| #endif | |||
| #endif | |||
| for(i=0; i<group_count; matrix_idx+=group_size[i], i++){ | |||
| group_alpha = (void *)&alpha_array[i * COMPSIZE]; | |||
| group_beta = (void *)&beta_array[i * COMPSIZE]; | |||
| group_transa = -1; | |||
| group_transb = -1; | |||
| info = 0; | |||
| if (order == CblasColMajor) { | |||
| group_m = m_array[i]; | |||
| group_n = n_array[i]; | |||
| group_k = k_array[i]; | |||
| group_lda = lda_array[i]; | |||
| group_ldb = ldb_array[i]; | |||
| group_ldc = ldc_array[i]; | |||
| if (transa_array[i] == CblasNoTrans) group_transa = 0; | |||
| if (transa_array[i] == CblasTrans) group_transa = 1; | |||
| #ifndef COMPLEX | |||
| if (transa_array[i] == CblasConjNoTrans) group_transa = 0; | |||
| if (transa_array[i] == CblasConjTrans) group_transa = 1; | |||
| #else | |||
| if (transa_array[i] == CblasConjNoTrans) group_transa = 2; | |||
| if (transa_array[i] == CblasConjTrans) group_transa = 3; | |||
| #endif | |||
| if (transb_array[i] == CblasNoTrans) group_transb = 0; | |||
| if (transb_array[i] == CblasTrans) group_transb = 1; | |||
| #ifndef COMPLEX | |||
| if (transb_array[i] == CblasConjNoTrans) group_transb = 0; | |||
| if (transb_array[i] == CblasConjTrans) group_transb = 1; | |||
| #else | |||
| if (transb_array[i] == CblasConjNoTrans) group_transb = 2; | |||
| if (transb_array[i] == CblasConjTrans) group_transb = 3; | |||
| #endif | |||
| group_nrowa = group_m; | |||
| if (group_transa & 1) group_nrowa = group_k; | |||
| group_nrowb = group_k; | |||
| if (group_transb & 1) group_nrowb = group_n; | |||
| info=-1; | |||
| if (group_ldc < group_m) info = 13; | |||
| if (group_ldb < group_nrowb) info = 10; | |||
| if (group_lda < group_nrowa) info = 8; | |||
| if (group_k < 0) info = 5; | |||
| if (group_n < 0) info = 4; | |||
| if (group_m < 0) info = 3; | |||
| if (group_transb < 0) info = 2; | |||
| if (group_transa < 0) info = 1; | |||
| }else if (order == CblasRowMajor) { | |||
| group_m = n_array[i]; | |||
| group_n = m_array[i]; | |||
| group_k = k_array[i]; | |||
| group_lda = ldb_array[i]; | |||
| group_ldb = lda_array[i]; | |||
| group_ldc = ldc_array[i]; | |||
| if (transb_array[i] == CblasNoTrans) group_transa = 0; | |||
| if (transb_array[i] == CblasTrans) group_transa = 1; | |||
| #ifndef COMPLEX | |||
| if (transb_array[i] == CblasConjNoTrans) group_transa = 0; | |||
| if (transb_array[i] == CblasConjTrans) group_transa = 1; | |||
| #else | |||
| if (transb_array[i] == CblasConjNoTrans) group_transa = 2; | |||
| if (transb_array[i] == CblasConjTrans) group_transa = 3; | |||
| #endif | |||
| if (transa_array[i] == CblasNoTrans) group_transb = 0; | |||
| if (transa_array[i] == CblasTrans) group_transb = 1; | |||
| #ifndef COMPLEX | |||
| if (transa_array[i] == CblasConjNoTrans) group_transb = 0; | |||
| if (transa_array[i] == CblasConjTrans) group_transb = 1; | |||
| #else | |||
| if (transa_array[i] == CblasConjNoTrans) group_transb = 2; | |||
| if (transa_array[i] == CblasConjTrans) group_transb = 3; | |||
| #endif | |||
| group_nrowa = group_m; | |||
| if (group_transa & 1) group_nrowa = group_k; | |||
| group_nrowb = group_k; | |||
| if (group_transb & 1) group_nrowb = group_n; | |||
| info=-1; | |||
| if (group_ldc < group_m) info = 13; | |||
| if (group_ldb < group_nrowb) info = 10; | |||
| if (group_lda < group_nrowa) info = 8; | |||
| if (group_k < 0) info = 5; | |||
| if (group_n < 0) info = 4; | |||
| if (group_m < 0) info = 3; | |||
| if (group_transb < 0) info = 2; | |||
| if (group_transa < 0) info = 1; | |||
| } | |||
| if (info >= 0) { | |||
| BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); | |||
| free(args_array); | |||
| return; | |||
| } | |||
| if (group_m == 0 || group_n == 0) continue; | |||
| group_mode=mode; | |||
| #if defined(SMP) || defined(SMALL_MATRIX_OPT) | |||
| MNK = (double) group_m * (double) group_n * (double) group_k; | |||
| #endif | |||
| #ifdef SMALL_MATRIX_OPT | |||
| if(MNK <= 100.0*100.0*100.0){ | |||
| group_routine=NULL; | |||
| #if !defined(COMPLEX) | |||
| if(*(FLOAT *)(group_beta) == 0.0){ | |||
| group_mode=mode | BLAS_SMALL_B0_OPT; | |||
| group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]); | |||
| }else{ | |||
| group_mode=mode | BLAS_SMALL_OPT; | |||
| group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]); | |||
| } | |||
| #else | |||
| if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){ | |||
| group_mode=mode | BLAS_SMALL_B0_OPT; | |||
| group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]); | |||
| }else{ | |||
| group_mode=mode | BLAS_SMALL_OPT; | |||
| group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]); | |||
| } | |||
| #endif | |||
| }else{ | |||
| #endif | |||
| group_routine=(void*)(gemm[(group_transb<<2)|group_transa]); | |||
| #ifdef SMALL_MATRIX_OPT | |||
| } | |||
| #endif | |||
| for(j=0; j<group_size[i]; j++){ | |||
| args_array[count].m=group_m; | |||
| args_array[count].n=group_n; | |||
| args_array[count].k=group_k; | |||
| args_array[count].lda=group_lda; | |||
| args_array[count].ldb=group_ldb; | |||
| args_array[count].ldc=group_ldc; | |||
| args_array[count].alpha=group_alpha; | |||
| args_array[count].beta=group_beta; | |||
| if (order == CblasColMajor) { | |||
| args_array[count].a=(a_array[matrix_idx+j]); | |||
| args_array[count].b=(b_array[matrix_idx+j]); | |||
| }else if(order == CblasRowMajor){ | |||
| args_array[count].a=(b_array[matrix_idx+j]); | |||
| args_array[count].b=(a_array[matrix_idx+j]); | |||
| } | |||
| args_array[count].c=(c_array[matrix_idx+j]); | |||
| args_array[count].routine_mode=group_mode; | |||
| args_array[count].routine=group_routine; | |||
| #ifdef SMALL_MATRIX_OPT | |||
| args_array[count].routine=group_small_matrix_opt_routine; | |||
| #endif | |||
| count++; | |||
| } | |||
| } | |||
| if(count>0){ | |||
| GEMM_BATCH_THREAD(args_array,count); | |||
| } | |||
| free(args_array); | |||
| } | |||