You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cgetrf.c 3.0 kB

(line numbers 1–117)
  #include "relapack.h"
  #include <stddef.h>
  2. static void RELAPACK_cgetrf_rec(const int *, const int *, float *,
  3. const int *, int *, int *);
  4. /** CGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
  5. *
  6. * This routine is functionally equivalent to LAPACK's cgetrf.
  7. * For details on its interface, see
  8. * http://www.netlib.org/lapack/explore-html/d9/dfb/cgetrf_8f.html
  9. */
  10. void RELAPACK_cgetrf(
  11. const int *m, const int *n,
  12. float *A, const int *ldA, int *ipiv,
  13. int *info
  14. ) {
  15. // Check arguments
  16. *info = 0;
  17. if (*m < 0)
  18. *info = -1;
  19. else if (*n < 0)
  20. *info = -2;
  21. else if (*ldA < MAX(1, *n))
  22. *info = -4;
  23. if (*info) {
  24. const int minfo = -*info;
  25. LAPACK(xerbla)("CGETRF", &minfo);
  26. return;
  27. }
  28. const int sn = MIN(*m, *n);
  29. RELAPACK_cgetrf_rec(m, &sn, A, ldA, ipiv, info);
  30. // Right remainder
  31. if (*m < *n) {
  32. // Constants
  33. const float ONE[] = { 1., 0. };
  34. const int iONE[] = { 1 };
  35. // Splitting
  36. const int rn = *n - *m;
  37. // A_L A_R
  38. const float *const A_L = A;
  39. float *const A_R = A + 2 * *ldA * *m;
  40. // A_R = apply(ipiv, A_R)
  41. LAPACK(claswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
  42. // A_R = A_L \ A_R
  43. BLAS(ctrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
  44. }
  45. }
  46. /** cgetrf's recursive compute kernel */
  47. static void RELAPACK_cgetrf_rec(
  48. const int *m, const int *n,
  49. float *A, const int *ldA, int *ipiv,
  50. int *info
  51. ) {
  52. if (*n <= MAX(CROSSOVER_CGETRF, 1)) {
  53. // Unblocked
  54. LAPACK(cgetf2)(m, n, A, ldA, ipiv, info);
  55. return;
  56. }
  57. // Constants
  58. const float ONE[] = { 1., 0. };
  59. const float MONE[] = { -1., 0. };
  60. const int iONE[] = { 1 };
  61. // Splitting
  62. const int n1 = CREC_SPLIT(*n);
  63. const int n2 = *n - n1;
  64. const int m2 = *m - n1;
  65. // A_L A_R
  66. float *const A_L = A;
  67. float *const A_R = A + 2 * *ldA * n1;
  68. // A_TL A_TR
  69. // A_BL A_BR
  70. float *const A_TL = A;
  71. float *const A_TR = A + 2 * *ldA * n1;
  72. float *const A_BL = A + 2 * n1;
  73. float *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
  74. // ipiv_T
  75. // ipiv_B
  76. int *const ipiv_T = ipiv;
  77. int *const ipiv_B = ipiv + n1;
  78. // recursion(A_L, ipiv_T)
  79. RELAPACK_cgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
  80. // apply pivots to A_R
  81. LAPACK(claswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
  82. // A_TR = A_TL \ A_TR
  83. BLAS(ctrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
  84. // A_BR = A_BR - A_BL * A_TR
  85. BLAS(cgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
  86. // recursion(A_BR, ipiv_B)
  87. RELAPACK_cgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
  88. if (*info)
  89. *info += n1;
  90. // apply pivots to A_BL
  91. LAPACK(claswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
  92. // shift pivots
  93. int i;
  94. for (i = 0; i < n2; i++)
  95. ipiv_B[i] += n1;
  96. }