You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

csytrf.c 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. #include "relapack.h"
  2. #if XSYTRF_ALLOW_MALLOC
  3. #include <stdlib.h>
  4. #endif
  5. static void RELAPACK_csytrf_rec(const char *, const blasint *, const blasint *, blasint *,
  6. float *, const blasint *, blasint *, float *, const blasint *, blasint *);
  7. /** CSYTRF computes the factorization of a complex symmetric matrix A using the Bunch-Kaufman diagonal pivoting method.
  8. *
  9. * This routine is functionally equivalent to LAPACK's csytrf.
  10. * For details on its interface, see
  11. * http://www.netlib.org/lapack/explore-html/d5/d21/csytrf_8f.html
  12. * */
  13. void RELAPACK_csytrf(
  14. const char *uplo, const blasint *n,
  15. float *A, const blasint *ldA, blasint *ipiv,
  16. float *Work, const blasint *lWork, blasint *info
  17. ) {
  18. // Required work size
  19. const blasint cleanlWork = *n * (*n / 2);
  20. blasint minlWork = cleanlWork;
  21. #if XSYTRF_ALLOW_MALLOC
  22. minlWork = 1;
  23. #endif
  24. // Check arguments
  25. const blasint lower = LAPACK(lsame)(uplo, "L");
  26. const blasint upper = LAPACK(lsame)(uplo, "U");
  27. *info = 0;
  28. if (!lower && !upper)
  29. *info = -1;
  30. else if (*n < 0)
  31. *info = -2;
  32. else if (*ldA < MAX(1, *n))
  33. *info = -4;
  34. else if (*lWork < minlWork && *lWork != -1)
  35. *info = -7;
  36. else if (*lWork == -1) {
  37. // Work size query
  38. *Work = cleanlWork;
  39. return;
  40. }
  41. // Ensure Work size
  42. float *cleanWork = Work;
  43. #if XSYTRF_ALLOW_MALLOC
  44. if (!*info && *lWork < cleanlWork) {
  45. cleanWork = malloc(cleanlWork * 2 * sizeof(float));
  46. if (!cleanWork)
  47. *info = -7;
  48. }
  49. #endif
  50. if (*info) {
  51. const blasint minfo = -*info;
  52. LAPACK(xerbla)("CSYTRF", &minfo, strlen("CSYTRF"));
  53. return;
  54. }
  55. // Clean char * arguments
  56. const char cleanuplo = lower ? 'L' : 'U';
  57. // Dummy arguments
  58. blasint nout;
  59. // Recursive kernel
  60. RELAPACK_csytrf_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
  61. #if XSYTRF_ALLOW_MALLOC
  62. if (cleanWork != Work)
  63. free(cleanWork);
  64. #endif
  65. }
  66. /** csytrf's recursive compute kernel.
   *
   * Processes *n columns either directly (when *n is at or below the
   * crossover) or by splitting them in two parts, recursing on the first
   * part, applying its update to the second part, and recursing on that.
   *
   * uplo          only the first character is inspected here: 'L' selects
   *               the first branch, anything else the second.
   * n_full        compared with *n to detect the top recursion level; the
   *               difference *n_full - *n (n_rest) sizes the cgemm/cgemv
   *               updates.
   * n             number of columns requested from this call.
   * n_out         receives the number of columns actually processed: *n in
   *               the top-level unblocked case, n1 + n2 as reported by the
   *               sub-calls in the recursive case; in the remaining case it
   *               is handed on to RELAPACK_csytrf_rec2.
   * A, ldA        matrix storage and its leading dimension; every offset
   *               into A is scaled by 2 (two floats per entry, cf. the
   *               ONE/MONE constants below).
   * ipiv          pivot indices; in the 'L' branch the entries written by
   *               the second recursion are shifted by n1 afterwards.
   * Work, ldWork  workspace and its leading dimension, used as the second
   *               operand of the cgemmt/cgemm updates.
   * info          written by csytf2 / RELAPACK_csytrf_rec2 in the unblocked
   *               case, otherwise set to the logical OR (0 or 1) of the two
   *               sub-calls' info values.
   */
  67. static void RELAPACK_csytrf_rec(
  68. const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out,
  69. float *A, const blasint *ldA, blasint *ipiv,
  70. float *Work, const blasint *ldWork, blasint *info
  71. ) {
  72. // top recursion level? (equivalent to n_rest == 0 below)
  73. const blasint top = *n_full == *n;
  74. if (*n <= MAX(CROSSOVER_CSYTRF, 3)) {
  75. // Unblocked: LAPACK's csytf2 at the top level, otherwise the rec2 kernel
  76. if (top) {
  77. LAPACK(csytf2)(uplo, n, A, ldA, ipiv, info);
  78. *n_out = *n;
  79. } else
  80. RELAPACK_csytrf_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
  81. return;
  82. }
  83. blasint info1, info2;
  84. // Constants (two-float values passed as alpha/beta to the BLAS calls)
  85. const float ONE[] = { 1., 0. };
  86. const float MONE[] = { -1., 0. };
  87. const blasint iONE[] = { 1 };
  88. // Loop iterator
  89. blasint i;
  // number of rows beyond the n columns handled by this call
  90. const blasint n_rest = *n_full - *n;
  91. if (*uplo == 'L') {
  92. // Splitting (setup)
  93. blasint n1 = CREC_SPLIT(*n);
  94. blasint n2 = *n - n1;
  95. // Work_L *
  96. float *const Work_L = Work;
  97. // recursion(A_L)
  98. blasint n1_out;
  99. RELAPACK_csytrf_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
  // adopt the column count the first recursion actually reported
  100. n1 = n1_out;
  101. // Splitting (continued): recompute the second part from the adjusted n1
  102. n2 = *n - n1;
  103. const blasint n_full2 = *n_full - n1;
  104. // *      *
  105. // A_BL   A_BR
  106. // A_BL_B A_BR_B
  107. float *const A_BL = A + 2 * n1;
  108. float *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
  109. float *const A_BL_B = A + 2 * *n;
  110. float *const A_BR_B = A + 2 * *ldA * n1 + 2 * *n;
  111. // *       *
  112. // Work_BL Work_BR
  113. // *       *
  114. // (top recursion level: use Work as Work_BR, with leading dimension n2)
  115. float *const Work_BL = Work + 2 * n1;
  116. float *const Work_BR = top ? Work : Work + 2 * *ldWork * n1 + 2 * n1;
  117. const blasint ldWork_BR = top ? n2 : *ldWork;
  118. // ipiv_T
  119. // ipiv_B
  120. blasint *const ipiv_B = ipiv + n1;
  121. // A_BR = A_BR - A_BL Work_BL'
  // (cgemmt updates the n2 x n2 block, cgemm the n_rest rows below it)
  122. RELAPACK_cgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
  123. BLAS(cgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
  124. // recursion(A_BR)
  125. blasint n2_out;
  126. RELAPACK_csytrf_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
  127. if (n2_out != n2) {
  128. // undo 1 column of updates (the second recursion reported a different
  // column count than requested, so one updated column is restored)
  129. const blasint n_restp1 = n_rest + 1;
  130. // last column of A_BR
  131. float *const A_BR_r = A_BR + 2 * *ldA * n2_out + 2 * n2_out;
  132. // last row of A_BL
  133. float *const A_BL_b = A_BL + 2 * n2_out;
  134. // last row of Work_BL
  135. float *const Work_BL_b = Work_BL + 2 * n2_out;
  136. // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
  137. BLAS(cgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
  138. }
  139. n2 = n2_out;
  140. // shift pivots: move positive entries up by n1 and all others down by n1
  141. for (i = 0; i < n2; i++)
  142. if (ipiv_B[i] > 0)
  143. ipiv_B[i] += n1;
  144. else
  145. ipiv_B[i] -= n1;
  // NOTE(review): logical OR collapses the sub-calls' info values to 0 or 1,
  // so any index they reported is not propagated - confirm this is intended
  146. *info = info1 || info2;
  147. *n_out = n1 + n2;
  148. } else {
  149. // Splitting (setup)
  150. blasint n2 = CREC_SPLIT(*n);
  151. blasint n1 = *n - n2;
  152. // * Work_R
  153. // (top recursion level: use Work as Work_R)
  154. float *const Work_R = top ? Work : Work + 2 * *ldWork * n1;
  155. // recursion(A_R)
  156. blasint n2_out;
  157. RELAPACK_csytrf_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
  // remember how many requested columns the recursion did not process
  158. const blasint n2_diff = n2 - n2_out;
  159. n2 = n2_out;
  160. // Splitting (continued): recompute the first part from the adjusted n2
  161. n1 = *n - n2;
  162. const blasint n_full1 = *n_full - n2;
  163. // * A_TL_T A_TR_T
  164. // * A_TL   A_TR
  165. // * *      *
  166. float *const A_TL_T = A + 2 * *ldA * n_rest;
  167. float *const A_TR_T = A + 2 * *ldA * (n_rest + n1);
  168. float *const A_TL = A + 2 * *ldA * n_rest + 2 * n_rest;
  169. float *const A_TR = A + 2 * *ldA * (n_rest + n1) + 2 * n_rest;
  170. // Work_L *
  171. // *      Work_TR
  172. // *      *
  173. // (top recursion level: Work_R was Work)
  174. float *const Work_L = Work;
  175. float *const Work_TR = Work + 2 * *ldWork * (top ? n2_diff : n1) + 2 * n_rest;
  176. const blasint ldWork_L = top ? n1 : *ldWork;
  177. // A_TL = A_TL - A_TR Work_TR'
  // (cgemmt updates the n1 x n1 block, cgemm the n_rest rows above it)
  178. RELAPACK_cgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
  179. BLAS(cgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
  180. // recursion(A_TL); note that A and ipiv are passed without an offset here
  181. blasint n1_out;
  182. RELAPACK_csytrf_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
  183. if (n1_out != n1) {
  184. // undo 1 column of updates
  185. const blasint n_restp1 = n_rest + 1;
  186. // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
  187. BLAS(cgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
  188. }
  189. n1 = n1_out;
  // no pivot shift in this branch; info is again reduced to 0 or 1
  190. *info = info2 || info1;
  191. *n_out = n1 + n2;
  192. }
  193. }