You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_batch.c 12 kB


  1. /*****************************************************************************
  2. Copyright (c) 2020, The OpenBLAS Project
  3. All rights reserved.
  4. Redistribution and use in source and binary forms, with or without
  5. modification, are permitted provided that the following conditions are
  6. met:
  7. 1. Redistributions of source code must retain the above copyright
  8. notice, this list of conditions and the following disclaimer.
  9. 2. Redistributions in binary form must reproduce the above copyright
  10. notice, this list of conditions and the following disclaimer in
  11. the documentation and/or other materials provided with the
  12. distribution.
  13. 3. Neither the name of the OpenBLAS project nor the names of
  14. its contributors may be used to endorse or promote products
  15. derived from this software without specific prior written
  16. permission.
  17. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  22. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  23. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  24. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  25. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
  26. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  27. **********************************************************************************/
  28. #include <stdio.h>
  29. #include <stdlib.h>
  30. #include "common.h"
  31. void openblas_warning(int verbose, const char * msg);
  32. #ifndef COMPLEX
  33. #ifdef XDOUBLE
  34. #define ERROR_NAME "QGEMM_BATCH "
  35. #elif defined(DOUBLE)
  36. #define ERROR_NAME "DGEMM_BATCH "
  37. #define GEMM_BATCH_THREAD dgemm_batch_thread
  38. #else
  39. #define ERROR_NAME "SGEMM_BATCH "
  40. #define GEMM_BATCH_THREAD sgemm_batch_thread
  41. #endif
  42. #else
  43. #ifdef XDOUBLE
  44. #define ERROR_NAME "XGEMM_BATCH "
  45. #elif defined(DOUBLE)
  46. #define ERROR_NAME "ZGEMM_BATCH "
  47. #define GEMM_BATCH_THREAD zgemm_batch_thread
  48. #else
  49. #define ERROR_NAME "CGEMM_BATCH "
  50. #define GEMM_BATCH_THREAD cgemm_batch_thread
  51. #endif
  52. #endif
  53. static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
  54. GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
  55. GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
  56. GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
  57. GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
  58. };
  59. #ifdef SMALL_MATRIX_OPT
  60. #ifndef COMPLEX
  61. static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
  62. #ifndef GEMM3M
  63. GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL,
  64. GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL,
  65. #endif
  66. };
  67. static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = {
  68. #ifndef GEMM3M
  69. GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL,
  70. GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL,
  71. #endif
  72. };
  73. #else
  74. static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG) = {
  75. #ifndef GEMM3M
  76. GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
  77. GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
  78. GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
  79. GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
  80. #endif
  81. };
  82. static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = {
  83. #ifndef GEMM3M
  84. GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
  85. GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
  86. GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
  87. GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
  88. #endif
  89. };
  90. #endif
  91. #endif
  92. void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array,
  93. blasint * m_array, blasint * n_array, blasint * k_array,
  94. #ifndef COMPLEX
  95. FLOAT * alpha_array,
  96. FLOAT ** a_array, blasint * lda_array,
  97. FLOAT ** b_array, blasint * ldb_array,
  98. FLOAT * beta_array,
  99. FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
  100. #else
  101. void * valpha_array,
  102. void ** va_array, blasint * lda_array,
  103. void ** vb_array, blasint * ldb_array,
  104. void * vbeta_array,
  105. void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
  106. FLOAT * alpha_array=(FLOAT *)valpha_array;
  107. FLOAT * beta_array=(FLOAT *)vbeta_array;
  108. FLOAT ** a_array=(FLOAT**)va_array;
  109. FLOAT ** b_array=(FLOAT**)vb_array;
  110. FLOAT ** c_array=(FLOAT**)vc_array;
  111. #endif
  112. blas_arg_t * args_array=NULL;
  113. int mode=0, group_mode=0;
  114. blasint total_num=0;
  115. blasint i=0, j=0, matrix_idx=0, count=0;
  116. int group_transa, group_transb;
  117. BLASLONG group_nrowa, group_nrowb;
  118. blasint info;
  119. void * group_alpha, * group_beta;
  120. BLASLONG group_m, group_n, group_k;
  121. BLASLONG group_lda, group_ldb, group_ldc;
  122. void * group_routine=NULL;
  123. #ifdef SMALL_MATRIX_OPT
  124. void * group_small_matrix_opt_routine=NULL;
  125. #endif
  126. #if defined (SMP) || defined(SMALL_MATRIX_OPT)
  127. double MNK;
  128. #endif
  129. PRINT_DEBUG_CNAME;
  130. for(i=0; i<group_count; i++){
  131. total_num+=group_size[i];
  132. }
  133. args_array=(blas_arg_t *)malloc(total_num * sizeof(blas_arg_t));
  134. if(args_array == NULL){
  135. openblas_warning(0, "memory alloc failed!\n");
  136. exit(1);
  137. }
  138. #ifdef SMP
  139. #ifndef COMPLEX
  140. #ifdef XDOUBLE
  141. mode = BLAS_XDOUBLE | BLAS_REAL;
  142. #elif defined(DOUBLE)
  143. mode = BLAS_DOUBLE | BLAS_REAL;
  144. #else
  145. mode = BLAS_SINGLE | BLAS_REAL;
  146. #endif
  147. #else
  148. #ifdef XDOUBLE
  149. mode = BLAS_XDOUBLE | BLAS_COMPLEX;
  150. #elif defined(DOUBLE)
  151. mode = BLAS_DOUBLE | BLAS_COMPLEX;
  152. #else
  153. mode = BLAS_SINGLE | BLAS_COMPLEX;
  154. #endif
  155. #endif
  156. #endif
  157. for(i=0; i<group_count; matrix_idx+=group_size[i], i++){
  158. group_alpha = (void *)&alpha_array[i * COMPSIZE];
  159. group_beta = (void *)&beta_array[i * COMPSIZE];
  160. group_transa = -1;
  161. group_transb = -1;
  162. info = 0;
  163. if (order == CblasColMajor) {
  164. group_m = m_array[i];
  165. group_n = n_array[i];
  166. group_k = k_array[i];
  167. group_lda = lda_array[i];
  168. group_ldb = ldb_array[i];
  169. group_ldc = ldc_array[i];
  170. if (transa_array[i] == CblasNoTrans) group_transa = 0;
  171. if (transa_array[i] == CblasTrans) group_transa = 1;
  172. #ifndef COMPLEX
  173. if (transa_array[i] == CblasConjNoTrans) group_transa = 0;
  174. if (transa_array[i] == CblasConjTrans) group_transa = 1;
  175. #else
  176. if (transa_array[i] == CblasConjNoTrans) group_transa = 2;
  177. if (transa_array[i] == CblasConjTrans) group_transa = 3;
  178. #endif
  179. if (transb_array[i] == CblasNoTrans) group_transb = 0;
  180. if (transb_array[i] == CblasTrans) group_transb = 1;
  181. #ifndef COMPLEX
  182. if (transb_array[i] == CblasConjNoTrans) group_transb = 0;
  183. if (transb_array[i] == CblasConjTrans) group_transb = 1;
  184. #else
  185. if (transb_array[i] == CblasConjNoTrans) group_transb = 2;
  186. if (transb_array[i] == CblasConjTrans) group_transb = 3;
  187. #endif
  188. group_nrowa = group_m;
  189. if (group_transa & 1) group_nrowa = group_k;
  190. group_nrowb = group_k;
  191. if (group_transb & 1) group_nrowb = group_n;
  192. info=-1;
  193. if (group_ldc < group_m) info = 13;
  194. if (group_ldb < group_nrowb) info = 10;
  195. if (group_lda < group_nrowa) info = 8;
  196. if (group_k < 0) info = 5;
  197. if (group_n < 0) info = 4;
  198. if (group_m < 0) info = 3;
  199. if (group_transb < 0) info = 2;
  200. if (group_transa < 0) info = 1;
  201. }else if (order == CblasRowMajor) {
  202. group_m = n_array[i];
  203. group_n = m_array[i];
  204. group_k = k_array[i];
  205. group_lda = ldb_array[i];
  206. group_ldb = lda_array[i];
  207. group_ldc = ldc_array[i];
  208. if (transb_array[i] == CblasNoTrans) group_transa = 0;
  209. if (transb_array[i] == CblasTrans) group_transa = 1;
  210. #ifndef COMPLEX
  211. if (transb_array[i] == CblasConjNoTrans) group_transa = 0;
  212. if (transb_array[i] == CblasConjTrans) group_transa = 1;
  213. #else
  214. if (transb_array[i] == CblasConjNoTrans) group_transa = 2;
  215. if (transb_array[i] == CblasConjTrans) group_transa = 3;
  216. #endif
  217. if (transa_array[i] == CblasNoTrans) group_transb = 0;
  218. if (transa_array[i] == CblasTrans) group_transb = 1;
  219. #ifndef COMPLEX
  220. if (transa_array[i] == CblasConjNoTrans) group_transb = 0;
  221. if (transa_array[i] == CblasConjTrans) group_transb = 1;
  222. #else
  223. if (transa_array[i] == CblasConjNoTrans) group_transb = 2;
  224. if (transa_array[i] == CblasConjTrans) group_transb = 3;
  225. #endif
  226. group_nrowa = group_m;
  227. if (group_transa & 1) group_nrowa = group_k;
  228. group_nrowb = group_k;
  229. if (group_transb & 1) group_nrowb = group_n;
  230. info=-1;
  231. if (group_ldc < group_m) info = 13;
  232. if (group_ldb < group_nrowb) info = 10;
  233. if (group_lda < group_nrowa) info = 8;
  234. if (group_k < 0) info = 5;
  235. if (group_n < 0) info = 4;
  236. if (group_m < 0) info = 3;
  237. if (group_transb < 0) info = 2;
  238. if (group_transa < 0) info = 1;
  239. }
  240. if (info >= 0) {
  241. BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
  242. free(args_array);
  243. return;
  244. }
  245. if (group_m == 0 || group_n == 0) continue;
  246. group_mode=mode;
  247. #if defined(SMP) || defined(SMALL_MATRIX_OPT)
  248. MNK = (double) group_m * (double) group_n * (double) group_k;
  249. #endif
  250. #ifdef SMALL_MATRIX_OPT
  251. if(MNK <= 100.0*100.0*100.0){
  252. group_routine=NULL;
  253. #if !defined(COMPLEX)
  254. if(*(FLOAT *)(group_beta) == 0.0){
  255. group_mode=mode | BLAS_SMALL_B0_OPT;
  256. group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]);
  257. }else{
  258. group_mode=mode | BLAS_SMALL_OPT;
  259. group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]);
  260. }
  261. #else
  262. if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){
  263. group_mode=mode | BLAS_SMALL_B0_OPT;
  264. group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]);
  265. }else{
  266. group_mode=mode | BLAS_SMALL_OPT;
  267. group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]);
  268. }
  269. #endif
  270. }else{
  271. #endif
  272. group_routine=(void*)(gemm[(group_transb<<2)|group_transa]);
  273. #ifdef SMALL_MATRIX_OPT
  274. }
  275. #endif
  276. for(j=0; j<group_size[i]; j++){
  277. args_array[count].m=group_m;
  278. args_array[count].n=group_n;
  279. args_array[count].k=group_k;
  280. args_array[count].lda=group_lda;
  281. args_array[count].ldb=group_ldb;
  282. args_array[count].ldc=group_ldc;
  283. args_array[count].alpha=group_alpha;
  284. args_array[count].beta=group_beta;
  285. if (order == CblasColMajor) {
  286. args_array[count].a=(a_array[matrix_idx+j]);
  287. args_array[count].b=(b_array[matrix_idx+j]);
  288. }else if(order == CblasRowMajor){
  289. args_array[count].a=(b_array[matrix_idx+j]);
  290. args_array[count].b=(a_array[matrix_idx+j]);
  291. }
  292. args_array[count].c=(c_array[matrix_idx+j]);
  293. args_array[count].routine_mode=group_mode;
  294. args_array[count].routine=group_routine;
  295. #ifdef SMALL_MATRIX_OPT
  296. args_array[count].routine=group_small_matrix_opt_routine;
  297. #endif
  298. count++;
  299. }
  300. }
  301. if(count>0){
  302. GEMM_BATCH_THREAD(args_array,count);
  303. }
  304. free(args_array);
  305. }