You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

complex.h 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
  17. #define MINDSPORE_CCSRC_UTILS_COPLEX_H_
  18. #include <complex>
  19. #include <limits>
  20. #ifdef ENABLE_GPU
  21. #include <thrust/complex.h>
  22. #include <cublas_v2.h>
  23. #endif
  24. #include "base/float16.h"
  25. #if defined(__CUDACC__)
  26. #define HOST_DEVICE __host__ __device__
  27. #else
  28. #define HOST_DEVICE
  29. #endif
  30. namespace mindspore {
  31. namespace utils {
  32. // Implement Complex for mindspore, inspired by std::complex.
  33. template <typename T>
  34. struct alignas(sizeof(T) * 2) Complex {
  35. Complex() = default;
  36. ~Complex() = default;
  37. Complex(const Complex<T> &other) noexcept = default;
  38. Complex(Complex<T> &&other) noexcept = default;
  39. Complex &operator=(const Complex<T> &other) noexcept = default;
  40. Complex &operator=(Complex<T> &&other) noexcept = default;
  41. HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
  42. template <typename U>
  43. inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
  44. template <typename U>
  45. inline explicit constexpr operator std::complex<U>() const {
  46. return std::complex<U>(std::complex<T>(real(), imag()));
  47. }
  48. HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
  49. #if defined(__CUDACC__)
  50. template <typename U>
  51. HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
  52. template <typename U>
  53. HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
  54. return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
  55. }
  56. #endif
  57. template <typename U = T>
  58. HOST_DEVICE inline explicit Complex(const Complex<U> &other)
  59. : real_(static_cast<T>(other.real())), imag_(static_cast<T>(other.imag())) {}
  60. HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
  61. HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
  62. HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
  63. HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
  64. HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
  65. HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
  66. HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
  67. HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
  68. HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
  69. HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
  70. HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
  71. #if defined(__CUDACC__)
  72. HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
  73. #else
  74. inline explicit operator float16() const { return static_cast<float16>(real_); }
  75. #endif
  76. HOST_DEVICE inline Complex<T> &operator=(const T &real) {
  77. real_ = real;
  78. imag_ = T();
  79. return *this;
  80. }
  81. HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
  82. real_ += real;
  83. return *this;
  84. }
  85. HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
  86. real_ -= real;
  87. return *this;
  88. }
  89. HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
  90. real_ *= real;
  91. imag_ *= real;
  92. return *this;
  93. }
  94. // Note: check division by zero before use it.
  95. HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
  96. real_ /= real;
  97. imag_ /= real;
  98. return *this;
  99. }
  100. template <typename U>
  101. HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
  102. real_ = z.real();
  103. imag_ = z.imag();
  104. return *this;
  105. }
  106. template <typename U>
  107. HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
  108. real_ += z.real();
  109. imag_ += z.imag();
  110. return *this;
  111. }
  112. template <typename U>
  113. HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
  114. real_ -= z.real();
  115. imag_ -= z.imag();
  116. return *this;
  117. }
  118. template <typename U>
  119. HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
  120. // Note: check division by zero before use it.
  121. template <typename U>
  122. HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
  123. HOST_DEVICE inline constexpr T real() const { return real_; }
  124. HOST_DEVICE inline constexpr T imag() const { return imag_; }
  125. HOST_DEVICE inline void real(T val) { real_ = val; }
  126. HOST_DEVICE inline void imag(T val) { imag_ = val; }
  127. private:
  128. T real_;
  129. T imag_;
  130. };
  131. template <typename T>
  132. template <typename U>
  133. HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
  134. const T real = real_ * z.real() - imag_ * z.imag();
  135. imag_ = real_ * z.imag() + imag_ * z.real();
  136. real_ = real;
  137. return *this;
  138. }
  139. // Note: check division by zero before use it.
  140. template <typename T>
  141. template <typename U>
  142. HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
  143. T a = real_;
  144. T b = imag_;
  145. U c = z.real();
  146. U d = z.imag();
  147. auto denominator = c * c + d * d;
  148. real_ = (a * c + b * d) / denominator;
  149. imag_ = (b * c - a * d) / denominator;
  150. return *this;
  151. }
  152. template <typename T>
  153. HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
  154. Complex<T> result = lhs;
  155. result += rhs;
  156. return result;
  157. }
  158. template <typename T>
  159. HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
  160. Complex<T> result = lhs;
  161. result += rhs;
  162. return result;
  163. }
  164. template <typename T>
  165. HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
  166. Complex<T> result = rhs;
  167. result += lhs;
  168. return result;
  169. }
  170. template <typename T>
  171. HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
  172. Complex<T> result = lhs;
  173. result -= rhs;
  174. return result;
  175. }
  176. template <typename T>
  177. HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
  178. Complex<T> result = lhs;
  179. result -= rhs;
  180. return result;
  181. }
  182. template <typename T>
  183. HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
  184. Complex<T> result(lhs, -rhs.imag());
  185. result -= rhs.real();
  186. return result;
  187. }
  188. template <typename T>
  189. HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
  190. Complex<T> result = lhs;
  191. result *= rhs;
  192. return result;
  193. }
  194. template <typename T>
  195. HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
  196. Complex<T> result = lhs;
  197. result *= rhs;
  198. return result;
  199. }
  200. template <typename T>
  201. HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
  202. Complex<T> result = rhs;
  203. result *= lhs;
  204. return result;
  205. }
  206. // Note: check division by zero before use it.
  207. template <typename T>
  208. HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
  209. Complex<T> result = lhs;
  210. result /= rhs;
  211. return result;
  212. }
  213. // Note: check division by zero before use it.
  214. template <typename T>
  215. HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
  216. Complex<T> result = lhs;
  217. result /= rhs;
  218. return result;
  219. }
  220. // Note: check division by zero before use it.
  221. template <typename T>
  222. HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
  223. Complex<T> result = lhs;
  224. result /= rhs;
  225. return result;
  226. }
  227. template <typename T>
  228. HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
  229. return z;
  230. }
  231. template <typename T>
  232. HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
  233. return Complex<T>(-z.real(), -z.imag());
  234. }
  235. template <typename T>
  236. HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
  237. return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
  238. }
  239. template <typename T>
  240. HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
  241. return lhs == rhs.real() && rhs.imag() == 0;
  242. }
  243. template <typename T>
  244. HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
  245. return lhs.real() == rhs && lhs.imag() == 0;
  246. }
  247. template <typename T>
  248. HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
  249. return !(lhs == rhs);
  250. }
  251. template <typename T>
  252. HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
  253. return !(lhs == rhs);
  254. }
  255. template <typename T>
  256. HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
  257. return !(lhs == rhs);
  258. }
  259. template <typename T>
  260. inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
  261. return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
  262. }
  263. template <typename T>
  264. HOST_DEVICE inline T abs(const Complex<T> &z) {
  265. #if defined(__CUDACC__)
  266. return thrust::abs(thrust::complex<T>(z));
  267. #else
  268. return std::abs(std::complex<T>(z));
  269. #endif
  270. }
  271. } // namespace utils
  272. } // namespace mindspore
  273. template <typename T>
  274. using Complex = mindspore::utils::Complex<T>;
  275. namespace std {
  276. template <typename T>
  277. class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
  278. } // namespace std
  279. #endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_