You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ctrsyl.c 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #include "relapack.h"
  2. static void RELAPACK_ctrsyl_rec(const char *, const char *, const blasint *,
  3. const blasint *, const blasint *, const float *, const blasint *, const float *,
  4. const blasint *, float *, const blasint *, float *, blasint *);
  5. /** CTRSYL solves the complex Sylvester matrix equation.
  6. *
  7. * This routine is functionally equivalent to LAPACK's ctrsyl.
  8. * For details on its interface, see
  9. * http://www.netlib.org/lapack/explore-html/d8/df4/ctrsyl_8f.html
  10. * */
  11. void RELAPACK_ctrsyl(
  12. const char *tranA, const char *tranB, const blasint *isgn,
  13. const blasint *m, const blasint *n,
  14. const float *A, const blasint *ldA, const float *B, const blasint *ldB,
  15. float *C, const blasint *ldC, float *scale,
  16. blasint *info
  17. ) {
  18. // Check arguments
  19. const blasint notransA = LAPACK(lsame)(tranA, "N");
  20. const blasint ctransA = LAPACK(lsame)(tranA, "C");
  21. const blasint notransB = LAPACK(lsame)(tranB, "N");
  22. const blasint ctransB = LAPACK(lsame)(tranB, "C");
  23. *info = 0;
  24. if (!ctransA && !notransA)
  25. *info = -1;
  26. else if (!ctransB && !notransB)
  27. *info = -2;
  28. else if (*isgn != 1 && *isgn != -1)
  29. *info = -3;
  30. else if (*m < 0)
  31. *info = -4;
  32. else if (*n < 0)
  33. *info = -5;
  34. else if (*ldA < MAX(1, *m))
  35. *info = -7;
  36. else if (*ldB < MAX(1, *n))
  37. *info = -9;
  38. else if (*ldC < MAX(1, *m))
  39. *info = -11;
  40. if (*info) {
  41. const blasint minfo = -*info;
  42. LAPACK(xerbla)("CTRSYL", &minfo, strlen("CTRSYL"));
  43. return;
  44. }
  45. if (*m == 0 || *n == 0) {
  46. *scale = 1.;
  47. return;
  48. }
  49. // Clean char * arguments
  50. const char cleantranA = notransA ? 'N' : 'C';
  51. const char cleantranB = notransB ? 'N' : 'C';
  52. // Recursive kernel
  53. RELAPACK_ctrsyl_rec(&cleantranA, &cleantranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
  54. }
  55. /** ctrsyl's recursive compute kernel */
  56. static void RELAPACK_ctrsyl_rec(
  57. const char *tranA, const char *tranB, const blasint *isgn,
  58. const blasint *m, const blasint *n,
  59. const float *A, const blasint *ldA, const float *B, const blasint *ldB,
  60. float *C, const blasint *ldC, float *scale,
  61. blasint *info
  62. ) {
  63. if (*m <= MAX(CROSSOVER_CTRSYL, 1) && *n <= MAX(CROSSOVER_CTRSYL, 1)) {
  64. // Unblocked
  65. RELAPACK_ctrsyl_rec2(tranA, tranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
  66. return;
  67. }
  68. // Constants
  69. const float ONE[] = { 1., 0. };
  70. const float MONE[] = { -1., 0. };
  71. const float MSGN[] = { -*isgn, 0. };
  72. const blasint iONE[] = { 1 };
  73. // Outputs
  74. float scale1[] = { 1., 0. };
  75. float scale2[] = { 1., 0. };
  76. blasint info1[] = { 0 };
  77. blasint info2[] = { 0 };
  78. if (*m > *n) {
  79. // Splitting
  80. const blasint m1 = CREC_SPLIT(*m);
  81. const blasint m2 = *m - m1;
  82. // A_TL A_TR
  83. // 0 A_BR
  84. const float *const A_TL = A;
  85. const float *const A_TR = A + 2 * *ldA * m1;
  86. const float *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
  87. // C_T
  88. // C_B
  89. float *const C_T = C;
  90. float *const C_B = C + 2 * m1;
  91. if (*tranA == 'N') {
  92. // recusion(A_BR, B, C_B)
  93. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale1, info1);
  94. // C_T = C_T - A_TR * C_B
  95. BLAS(cgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
  96. // recusion(A_TL, B, C_T)
  97. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale2, info2);
  98. // apply scale
  99. if (scale2[0] != 1)
  100. LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
  101. } else {
  102. // recusion(A_TL, B, C_T)
  103. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale1, info1);
  104. // C_B = C_B - A_TR' * C_T
  105. BLAS(cgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
  106. // recusion(A_BR, B, C_B)
  107. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale2, info2);
  108. // apply scale
  109. if (scale2[0] != 1)
  110. LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_B, ldC, info);
  111. }
  112. } else {
  113. // Splitting
  114. const blasint n1 = CREC_SPLIT(*n);
  115. const blasint n2 = *n - n1;
  116. // B_TL B_TR
  117. // 0 B_BR
  118. const float *const B_TL = B;
  119. const float *const B_TR = B + 2 * *ldB * n1;
  120. const float *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
  121. // C_L C_R
  122. float *const C_L = C;
  123. float *const C_R = C + 2 * *ldC * n1;
  124. if (*tranB == 'N') {
  125. // recusion(A, B_TL, C_L)
  126. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale1, info1);
  127. // C_R = C_R -/+ C_L * B_TR
  128. BLAS(cgemm)("N", "N", m, &n2, &n1, MSGN, C_L, ldC, B_TR, ldB, scale1, C_R, ldC);
  129. // recusion(A, B_BR, C_R)
  130. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale2, info2);
  131. // apply scale
  132. if (scale2[0] != 1)
  133. LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
  134. } else {
  135. // recusion(A, B_BR, C_R)
  136. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale1, info1);
  137. // C_L = C_L -/+ C_R * B_TR'
  138. BLAS(cgemm)("N", "C", m, &n1, &n2, MSGN, C_R, ldC, B_TR, ldB, scale1, C_L, ldC);
  139. // recusion(A, B_TL, C_L)
  140. RELAPACK_ctrsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale2, info2);
  141. // apply scale
  142. if (scale2[0] != 1)
  143. LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
  144. }
  145. }
  146. *scale = scale1[0] * scale2[0];
  147. *info = info1[0] || info2[0];
  148. }