You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sgetrf.c 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
#include "relapack.h"
#include <stddef.h>
  2. static void RELAPACK_sgetrf_rec(const blasint *, const blasint *, float *, const blasint *,
  3. blasint *, blasint *);
  4. /** SGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
  5. *
  6. * This routine is functionally equivalent to LAPACK's sgetrf.
  7. * For details on its interface, see
  8. * http://www.netlib.org/lapack/explore-html/de/de2/sgetrf_8f.html
  9. * */
  10. void RELAPACK_sgetrf(
  11. const blasint *m, const blasint *n,
  12. float *A, const blasint *ldA, blasint *ipiv,
  13. blasint *info
  14. ) {
  15. // Check arguments
  16. *info = 0;
  17. if (*m < 0)
  18. *info = -1;
  19. else if (*n < 0)
  20. *info = -2;
  21. else if (*ldA < MAX(1, *m))
  22. *info = -4;
  23. if (*info) {
  24. const blasint minfo = -*info;
  25. LAPACK(xerbla)("SGETRF", &minfo, strlen("SGETRF"));
  26. return;
  27. }
  28. if (*m == 0 || *n == 0) return;
  29. const blasint sn = MIN(*m, *n);
  30. RELAPACK_sgetrf_rec(m, &sn, A, ldA, ipiv, info);
  31. // Right remainder
  32. if (*m < *n) {
  33. // Constants
  34. const float ONE[] = { 1. };
  35. const blasint iONE[] = { 1 };
  36. // Splitting
  37. const blasint rn = *n - *m;
  38. // A_L A_R
  39. const float *const A_L = A;
  40. float *const A_R = A + *ldA * *m;
  41. // A_R = apply(ipiv, A_R)
  42. LAPACK(slaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
  43. // A_R = A_L \ A_R
  44. BLAS(strsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
  45. }
  46. }
  47. /** sgetrf's recursive compute kernel */
  48. static void RELAPACK_sgetrf_rec(
  49. const blasint *m, const blasint *n,
  50. float *A, const blasint *ldA, blasint *ipiv,
  51. blasint *info
  52. ) {
  53. if (*m == 0 || *n == 0) return;
  54. if ( *n <= MAX(CROSSOVER_SGETRF, 1)) {
  55. // Unblocked
  56. LAPACK(sgetrf2)(m, n, A, ldA, ipiv, info);
  57. return;
  58. }
  59. // Constants
  60. const float ONE[] = { 1. };
  61. const float MONE[] = { -1. };
  62. const blasint iONE[] = { 1 };
  63. // Splitting
  64. const blasint n1 = SREC_SPLIT(*n);
  65. const blasint n2 = *n - n1;
  66. const blasint m2 = *m - n1;
  67. // A_L A_R
  68. float *const A_L = A;
  69. float *const A_R = A + *ldA * n1;
  70. // A_TL A_TR
  71. // A_BL A_BR
  72. float *const A_TL = A;
  73. float *const A_TR = A + *ldA * n1;
  74. float *const A_BL = A + n1;
  75. float *const A_BR = A + *ldA * n1 + n1;
  76. // ipiv_T
  77. // ipiv_B
  78. blasint *const ipiv_T = ipiv;
  79. blasint *const ipiv_B = ipiv + n1;
  80. // recursion(A_L, ipiv_T)
  81. RELAPACK_sgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
  82. if (*info)
  83. return;
  84. // apply pivots to A_R
  85. LAPACK(slaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
  86. // A_TR = A_TL \ A_TR
  87. BLAS(strsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
  88. // A_BR = A_BR - A_BL * A_TR
  89. BLAS(sgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
  90. // recursion(A_BR, ipiv_B)
  91. RELAPACK_sgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
  92. if (*info)
  93. *info += n1;
  94. // apply pivots to A_BL
  95. LAPACK(slaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
  96. // shift pivots
  97. blasint i;
  98. for (i = 0; i < n2; i++)
  99. ipiv_B[i] += n1;
  100. }