You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dpbtrf.c 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #include "relapack.h"
  2. #include "stdlib.h"
  3. static void RELAPACK_dpbtrf_rec(const char *, const blasint *, const blasint *,
  4. double *, const blasint *, double *, const blasint *, blasint *);
  5. /** DPBTRF computes the Cholesky factorization of a real symmetric positive definite band matrix A.
  6. *
  7. * This routine is functionally equivalent to LAPACK's dpbtrf.
  8. * For details on its interface, see
  9. * http://www.netlib.org/lapack/explore-html/df/da9/dpbtrf_8f.html
  10. * */
  11. void RELAPACK_dpbtrf(
  12. const char *uplo, const blasint *n, const blasint *kd,
  13. double *Ab, const blasint *ldAb,
  14. blasint *info
  15. ) {
  16. // Check arguments
  17. const blasint lower = LAPACK(lsame)(uplo, "L");
  18. const blasint upper = LAPACK(lsame)(uplo, "U");
  19. *info = 0;
  20. if (!lower && !upper)
  21. *info = -1;
  22. else if (*n < 0)
  23. *info = -2;
  24. else if (*kd < 0)
  25. *info = -3;
  26. else if (*ldAb < *kd + 1)
  27. *info = -5;
  28. if (*info) {
  29. const blasint minfo = -*info;
  30. LAPACK(xerbla)("DPBTRF", &minfo, strlen("DPBTRF"));
  31. return;
  32. }
  33. // Clean char * arguments
  34. const char cleanuplo = lower ? 'L' : 'U';
  35. // Constant
  36. const double ZERO[] = { 0. };
  37. // Allocate work space
  38. const blasint n1 = DREC_SPLIT(*n);
  39. const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
  40. const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
  41. double *Work = malloc(mWork * nWork * sizeof(double));
  42. LAPACK(dlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
  43. // Recursive kernel
  44. RELAPACK_dpbtrf_rec(&cleanuplo, n, kd, Ab, ldAb, Work, &mWork, info);
  45. // Free work space
  46. free(Work);
  47. }
  48. /** dpbtrf's recursive compute kernel */
  49. static void RELAPACK_dpbtrf_rec(
  50. const char *uplo, const blasint *n, const blasint *kd,
  51. double *Ab, const blasint *ldAb,
  52. double *Work, const blasint *ldWork,
  53. blasint *info
  54. ){
  55. if (*n <= MAX(CROSSOVER_DPBTRF, 1)) {
  56. // Unblocked
  57. LAPACK(dpbtf2)(uplo, n, kd, Ab, ldAb, info);
  58. return;
  59. }
  60. // Constants
  61. const double ONE[] = { 1. };
  62. const double MONE[] = { -1. };
  63. // Unskew A
  64. const blasint ldA[] = { *ldAb - 1 };
  65. double *const A = Ab + ((*uplo == 'L') ? 0 : *kd);
  66. // Splitting
  67. const blasint n1 = MIN(DREC_SPLIT(*n), *kd);
  68. const blasint n2 = *n - n1;
  69. // * *
  70. // * Ab_BR
  71. double *const Ab_BR = Ab + *ldAb * n1;
  72. // A_TL A_TR
  73. // A_BL A_BR
  74. double *const A_TL = A;
  75. double *const A_TR = A + *ldA * n1;
  76. double *const A_BL = A + n1;
  77. double *const A_BR = A + *ldA * n1 + n1;
  78. // recursion(A_TL)
  79. RELAPACK_dpotrf(uplo, &n1, A_TL, ldA, info);
  80. if (*info)
  81. return;
  82. // Banded splitting
  83. const blasint n21 = MIN(n2, *kd - n1);
  84. const blasint n22 = MIN(n2 - n21, n1);
  85. // n1 n21 n22
  86. // n1 * A_TRl A_TRr
  87. // n21 A_BLt A_BRtl A_BRtr
  88. // n22 A_BLb A_BRbl A_BRbr
  89. double *const A_TRl = A_TR;
  90. double *const A_TRr = A_TR + *ldA * n21;
  91. double *const A_BLt = A_BL;
  92. double *const A_BLb = A_BL + n21;
  93. double *const A_BRtl = A_BR;
  94. double *const A_BRtr = A_BR + *ldA * n21;
  95. double *const A_BRbl = A_BR + n21;
  96. double *const A_BRbr = A_BR + *ldA * n21 + n21;
  97. if (*uplo == 'L') {
  98. // A_BLt = ABLt / A_TL'
  99. BLAS(dtrsm)("R", "L", "T", "N", &n21, &n1, ONE, A_TL, ldA, A_BLt, ldA);
  100. // A_BRtl = A_BRtl - A_BLt * A_BLt'
  101. BLAS(dsyrk)("L", "N", &n21, &n1, MONE, A_BLt, ldA, ONE, A_BRtl, ldA);
  102. // Work = A_BLb
  103. LAPACK(dlacpy)("U", &n22, &n1, A_BLb, ldA, Work, ldWork);
  104. // Work = Work / A_TL'
  105. BLAS(dtrsm)("R", "L", "T", "N", &n22, &n1, ONE, A_TL, ldA, Work, ldWork);
  106. // A_BRbl = A_BRbl - Work * A_BLt'
  107. BLAS(dgemm)("N", "T", &n22, &n21, &n1, MONE, Work, ldWork, A_BLt, ldA, ONE, A_BRbl, ldA);
  108. // A_BRbr = A_BRbr - Work * Work'
  109. BLAS(dsyrk)("L", "N", &n22, &n1, MONE, Work, ldWork, ONE, A_BRbr, ldA);
  110. // A_BLb = Work
  111. LAPACK(dlacpy)("U", &n22, &n1, Work, ldWork, A_BLb, ldA);
  112. } else {
  113. // A_TRl = A_TL' \ A_TRl
  114. BLAS(dtrsm)("L", "U", "T", "N", &n1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
  115. // A_BRtl = A_BRtl - A_TRl' * A_TRl
  116. BLAS(dsyrk)("U", "T", &n21, &n1, MONE, A_TRl, ldA, ONE, A_BRtl, ldA);
  117. // Work = A_TRr
  118. LAPACK(dlacpy)("L", &n1, &n22, A_TRr, ldA, Work, ldWork);
  119. // Work = A_TL' \ Work
  120. BLAS(dtrsm)("L", "U", "T", "N", &n1, &n22, ONE, A_TL, ldA, Work, ldWork);
  121. // A_BRtr = A_BRtr - A_TRl' * Work
  122. BLAS(dgemm)("T", "N", &n21, &n22, &n1, MONE, A_TRl, ldA, Work, ldWork, ONE, A_BRtr, ldA);
  123. // A_BRbr = A_BRbr - Work' * Work
  124. BLAS(dsyrk)("U", "T", &n22, &n1, MONE, Work, ldWork, ONE, A_BRbr, ldA);
  125. // A_TRr = Work
  126. LAPACK(dlacpy)("L", &n1, &n22, Work, ldWork, A_TRr, ldA);
  127. }
  128. // recursion(A_BR)
  129. if (*kd > n1)
  130. RELAPACK_dpotrf(uplo, &n2, A_BR, ldA, info);
  131. else
  132. RELAPACK_dpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);
  133. if (*info)
  134. *info += n1;
  135. }