You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

complex_test.cc 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
  19. namespace mindspore {
  20. class TestComplex : public UT::Common {
  21. public:
  22. TestComplex() {}
  23. };
  24. TEST_F(TestComplex, test_size) {
  25. ASSERT_EQ(sizeof(Complex<float>), 2 * sizeof(float));
  26. ASSERT_EQ(sizeof(Complex<double>), 2 * sizeof(double));
  27. ASSERT_EQ(alignof(Complex<float>), 2 * sizeof(float));
  28. ASSERT_EQ(alignof(Complex<double>), 2 * sizeof(double));
  29. }
  30. template <typename T>
  31. void test_construct() {
  32. constexpr T real = T(1.11f);
  33. constexpr T imag = T(2.22f);
  34. ASSERT_EQ(Complex<T>().real(), T());
  35. ASSERT_EQ(Complex<T>().imag(), T());
  36. ASSERT_EQ(Complex<T>(real, imag).real(), real);
  37. ASSERT_EQ(Complex<T>(real, imag).imag(), imag);
  38. ASSERT_EQ(Complex<T>(real).real(), real);
  39. ASSERT_EQ(Complex<T>(real).imag(), T());
  40. }
  41. template <typename T1, typename T2>
  42. void test_conver_construct() {
  43. ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).real(), T1(1.11f));
  44. ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).imag(), T1(2.22f));
  45. }
  46. template <typename T>
  47. void test_conver_std_construct() {
  48. ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).real(), T(1.11f));
  49. ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).imag(), T(2.22f));
  50. }
  51. TEST_F(TestComplex, test_construct) {
  52. test_construct<float>();
  53. test_construct<double>();
  54. test_conver_construct<float, float>();
  55. test_conver_construct<double, double>();
  56. test_conver_construct<float, double>();
  57. test_conver_construct<double, float>();
  58. test_conver_std_construct<float>();
  59. test_conver_std_construct<double>();
  60. }
  61. template <typename T>
  62. void test_convert_operator(T &&a) {
  63. ASSERT_EQ(static_cast<T>(Complex<float>(a)), a);
  64. }
  65. TEST_F(TestComplex, test_convert_operator) {
  66. test_convert_operator<bool>(true);
  67. test_convert_operator<signed char>(1);
  68. test_convert_operator<unsigned char>(1);
  69. ASSERT_NEAR(static_cast<double>(Complex<float>(1.11)), 1.11, 0.001);
  70. test_convert_operator<float>(1.11f);
  71. test_convert_operator<int16_t>(1);
  72. test_convert_operator<uint16_t>(1);
  73. test_convert_operator<int32_t>(1);
  74. test_convert_operator<uint32_t>(1);
  75. test_convert_operator<int64_t>(1);
  76. test_convert_operator<uint64_t>(1);
  77. float16 a(1.11f);
  78. ASSERT_EQ(static_cast<float16>(Complex<float>(a)), a);
  79. }
  80. TEST_F(TestComplex, test_assign_operator) {
  81. Complex<float> a = 1.11f;
  82. std::cout << a << std::endl;
  83. ASSERT_EQ(a.real(), 1.11f);
  84. ASSERT_EQ(a.imag(), float());
  85. a = Complex<double>(2.22f, 1.11f);
  86. ASSERT_EQ(a.real(), 2.22f);
  87. ASSERT_EQ(a.imag(), 1.11f);
  88. }
  /// Checks `lhs + rhs == r` and, when the left operand is not a plain
  /// float/double, also checks that `lhs += rhs` yields `r`.
  ///
  /// @param lhs Left operand (taken by value so `+=` may modify it).
  /// @param rhs Right operand.
  /// @param r   Expected result.
  template <typename T1, typename T2, typename T3>
  void test_arithmetic_add(T1 lhs, T2 rhs, T3 r) {
    // The compound-assignment check is skipped for a plain float/double left operand.
    constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

    ASSERT_EQ(lhs + rhs, r);
    if constexpr (!kScalarLhs) {
      ASSERT_EQ(lhs += rhs, r);
    }
  }
  /// Checks `lhs - rhs == r` and, when the left operand is not a plain
  /// float/double, also checks that `lhs -= rhs` yields `r`.
  ///
  /// @param lhs Left operand (taken by value so `-=` may modify it).
  /// @param rhs Right operand.
  /// @param r   Expected result.
  template <typename T1, typename T2, typename T3>
  void test_arithmetic_sub(T1 lhs, T2 rhs, T3 r) {
    // The compound-assignment check is skipped for a plain float/double left operand.
    constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

    ASSERT_EQ(lhs - rhs, r);
    if constexpr (!kScalarLhs) {
      ASSERT_EQ(lhs -= rhs, r);
    }
  }
  /// Checks `lhs * rhs == r` and, when the left operand is not a plain
  /// float/double, also checks that `lhs *= rhs` yields `r`.
  ///
  /// @param lhs Left operand (taken by value so `*=` may modify it).
  /// @param rhs Right operand.
  /// @param r   Expected result.
  template <typename T1, typename T2, typename T3>
  void test_arithmetic_mul(T1 lhs, T2 rhs, T3 r) {
    // The compound-assignment check is skipped for a plain float/double left operand.
    constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

    ASSERT_EQ(lhs * rhs, r);
    if constexpr (!kScalarLhs) {
      ASSERT_EQ(lhs *= rhs, r);
    }
  }
  /// Checks `lhs / rhs == r` and, when the left operand is not a plain
  /// float/double, also checks that `lhs /= rhs` yields `r`.
  ///
  /// @param lhs Left operand (taken by value so `/=` may modify it).
  /// @param rhs Right operand.
  /// @param r   Expected result.
  template <typename T1, typename T2, typename T3>
  void test_arithmetic_div(T1 lhs, T2 rhs, T3 r) {
    // The compound-assignment check is skipped for a plain float/double left operand.
    constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

    ASSERT_EQ(lhs / rhs, r);
    if constexpr (!kScalarLhs) {
      ASSERT_EQ(lhs /= rhs, r);
    }
  }
  117. TEST_F(TestComplex, test_arithmetic) {
  118. test_arithmetic_add<Complex<float>, Complex<float>, Complex<float>>(
  119. Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(2.22, 4.44));
  120. test_arithmetic_add<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
  121. Complex<float>(2.22, 2.22));
  122. test_arithmetic_add<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
  123. Complex<float>(2.22, 2.22));
  124. test_arithmetic_sub<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
  125. Complex<float>(1.11, 2.22), Complex<float>(0, 0));
  126. test_arithmetic_sub<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(0, 2.22));
  127. test_arithmetic_sub<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
  128. Complex<float>(0, -2.22));
  129. test_arithmetic_mul<Complex<float>, Complex<float>, Complex<float>>(
  130. Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(-3.6963, 4.9284));
  131. test_arithmetic_mul<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
  132. Complex<float>(1.2321, 2.4642));
  133. test_arithmetic_mul<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
  134. Complex<float>(1.2321, 2.4642));
  135. test_arithmetic_div<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
  136. Complex<float>(1.11, 2.22), Complex<float>(1, 0));
  137. test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2));
  138. test_arithmetic_div<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
  139. Complex<float>(0.2, -0.4));
  140. }
  141. } // namespace mindspore