You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

strtri.c 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #include "relapack.h"
  2. static void RELAPACK_strtri_rec(const char *, const char *, const blasint *,
  3. float *, const blasint *, blasint *);
  4. /** CTRTRI computes the inverse of a real upper or lower triangular matrix A.
  5. *
  6. * This routine is functionally equivalent to LAPACK's strtri.
  7. * For details on its interface, see
  8. * http://www.netlib.org/lapack/explore-html/de/d76/strtri_8f.html
  9. * */
  10. void RELAPACK_strtri(
  11. const char *uplo, const char *diag, const blasint *n,
  12. float *A, const blasint *ldA,
  13. blasint *info
  14. ) {
  15. // Check arguments
  16. const blasint lower = LAPACK(lsame)(uplo, "L");
  17. const blasint upper = LAPACK(lsame)(uplo, "U");
  18. const blasint nounit = LAPACK(lsame)(diag, "N");
  19. const blasint unit = LAPACK(lsame)(diag, "U");
  20. *info = 0;
  21. if (!lower && !upper)
  22. *info = -1;
  23. else if (!nounit && !unit)
  24. *info = -2;
  25. else if (*n < 0)
  26. *info = -3;
  27. else if (*ldA < MAX(1, *n))
  28. *info = -5;
  29. if (*info) {
  30. const blasint minfo = -*info;
  31. LAPACK(xerbla)("STRTRI", &minfo, strlen("STRTRI"));
  32. return;
  33. }
  34. // Clean char * arguments
  35. const char cleanuplo = lower ? 'L' : 'U';
  36. const char cleandiag = nounit ? 'N' : 'U';
  37. // check for singularity
  38. if (nounit) {
  39. blasint i;
  40. for (i = 0; i < *n; i++)
  41. if (A[i + *ldA * i] == 0) {
  42. *info = i;
  43. return;
  44. }
  45. }
  46. // Recursive kernel
  47. RELAPACK_strtri_rec(&cleanuplo, &cleandiag, n, A, ldA, info);
  48. }
  49. /** strtri's recursive compute kernel */
  50. static void RELAPACK_strtri_rec(
  51. const char *uplo, const char *diag, const blasint *n,
  52. float *A, const blasint *ldA,
  53. blasint *info
  54. ){
  55. if (*n <= MAX(CROSSOVER_STRTRI, 1)) {
  56. // Unblocked
  57. LAPACK(strti2)(uplo, diag, n, A, ldA, info);
  58. return;
  59. }
  60. // Constants
  61. const float ONE[] = { 1. };
  62. const float MONE[] = { -1. };
  63. // Splitting
  64. const blasint n1 = SREC_SPLIT(*n);
  65. const blasint n2 = *n - n1;
  66. // A_TL A_TR
  67. // A_BL A_BR
  68. float *const A_TL = A;
  69. float *const A_TR = A + *ldA * n1;
  70. float *const A_BL = A + n1;
  71. float *const A_BR = A + *ldA * n1 + n1;
  72. // recursion(A_TL)
  73. RELAPACK_strtri_rec(uplo, diag, &n1, A_TL, ldA, info);
  74. if (*info)
  75. return;
  76. if (*uplo == 'L') {
  77. // A_BL = - A_BL * A_TL
  78. BLAS(strmm)("R", "L", "N", diag, &n2, &n1, MONE, A_TL, ldA, A_BL, ldA);
  79. // A_BL = A_BR \ A_BL
  80. BLAS(strsm)("L", "L", "N", diag, &n2, &n1, ONE, A_BR, ldA, A_BL, ldA);
  81. } else {
  82. // A_TR = - A_TL * A_TR
  83. BLAS(strmm)("L", "U", "N", diag, &n1, &n2, MONE, A_TL, ldA, A_TR, ldA);
  84. // A_TR = A_TR / A_BR
  85. BLAS(strsm)("R", "U", "N", diag, &n1, &n2, ONE, A_BR, ldA, A_TR, ldA);
  86. }
  87. // recursion(A_BR)
  88. RELAPACK_strtri_rec(uplo, diag, &n2, A_BR, ldA, info);
  89. if (*info)
  90. *info += n1;
  91. }