You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cc_implementations_test.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include <vector>
  19. #include "common/common_test.h"
  20. #include "frontend/operator/cc_implementations.h"
  21. namespace mindspore {
  22. namespace prim {
  23. class TestImplementations : public UT::Common {
  24. public:
  25. TestImplementations() {}
  26. virtual void SetUp() {}
  27. };
  28. TEST_F(TestImplementations, ScalarAddTest) {
  29. ValuePtrList list;
  30. list.push_back(MakeValue(static_cast<int64_t>(1)));
  31. list.push_back(MakeValue(static_cast<int64_t>(2)));
  32. ASSERT_EQ(ScalarAdd(list)->cast<Int64ImmPtr>()->value(), 3);
  33. list.clear();
  34. list.push_back(MakeValue(1.0f));
  35. list.push_back(MakeValue(1.5f));
  36. ASSERT_EQ(ScalarAdd(list)->cast<FP32ImmPtr>()->value(), 2.5f);
  37. list.clear();
  38. list.push_back(MakeValue(3.0));
  39. list.push_back(MakeValue(0.5));
  40. ASSERT_EQ(ScalarAdd(list)->cast<FP64ImmPtr>()->value(), 3.5);
  41. list.clear();
  42. list.push_back(MakeValue(INT64_MAX));
  43. list.push_back(MakeValue(static_cast<int64_t>(2)));
  44. try {
  45. ScalarAdd(list);
  46. FAIL();
  47. } catch (std::runtime_error const &err) {
  48. ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos);
  49. }
  50. list.clear();
  51. list.push_back(MakeValue(INT64_MIN));
  52. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  53. try {
  54. ScalarAdd(list);
  55. FAIL();
  56. } catch (std::runtime_error const &err) {
  57. ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos);
  58. }
  59. list.clear();
  60. }
  61. TEST_F(TestImplementations, ScalarSubTest) {
  62. ValuePtrList list;
  63. list.push_back(MakeValue(static_cast<int64_t>(1)));
  64. list.push_back(MakeValue(static_cast<int64_t>(3)));
  65. ASSERT_EQ(ScalarSub(list)->cast<Int64ImmPtr>()->value(), -2);
  66. list.clear();
  67. list.push_back(MakeValue(1.0f));
  68. list.push_back(MakeValue(1.5f));
  69. ASSERT_EQ(ScalarSub(list)->cast<FP32ImmPtr>()->value(), -0.5f);
  70. list.clear();
  71. list.push_back(MakeValue(3.0));
  72. list.push_back(MakeValue(0.5));
  73. ASSERT_EQ(ScalarSub(list)->cast<FP64ImmPtr>()->value(), 2.5);
  74. list.clear();
  75. list.push_back(MakeValue(INT64_MAX));
  76. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  77. try {
  78. ScalarSub(list);
  79. FAIL();
  80. } catch (std::runtime_error const &err) {
  81. ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos);
  82. }
  83. list.clear();
  84. list.push_back(MakeValue(INT64_MIN));
  85. list.push_back(MakeValue(static_cast<int64_t>(1)));
  86. try {
  87. ScalarSub(list);
  88. FAIL();
  89. } catch (std::runtime_error const &err) {
  90. ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos);
  91. }
  92. list.clear();
  93. }
  94. TEST_F(TestImplementations, ScalarMulTest) {
  95. ValuePtrList list;
  96. list.push_back(MakeValue(static_cast<int64_t>(2)));
  97. list.push_back(MakeValue(static_cast<int64_t>(3)));
  98. ASSERT_EQ(ScalarMul(list)->cast<Int64ImmPtr>()->value(), 6);
  99. list.clear();
  100. list.push_back(MakeValue(2.0f));
  101. list.push_back(MakeValue(1.5f));
  102. ASSERT_EQ(ScalarMul(list)->cast<FP32ImmPtr>()->value(), 3.0f);
  103. list.clear();
  104. list.push_back(MakeValue(-2.0));
  105. list.push_back(MakeValue(-4.0));
  106. ASSERT_EQ(ScalarMul(list)->cast<FP64ImmPtr>()->value(), 8.0);
  107. list.clear();
  108. list.push_back(MakeValue(static_cast<int64_t>(10)));
  109. list.push_back(MakeValue(INT64_MAX));
  110. try {
  111. ScalarMul(list);
  112. FAIL();
  113. } catch (std::runtime_error const &err) {
  114. ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
  115. }
  116. list.clear();
  117. list.push_back(MakeValue(INT64_MIN));
  118. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  119. try {
  120. ScalarMul(list);
  121. FAIL();
  122. } catch (std::runtime_error const &err) {
  123. ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
  124. }
  125. list.clear();
  126. list.push_back(MakeValue(static_cast<int64_t>(-2)));
  127. list.push_back(MakeValue(INT64_MAX));
  128. try {
  129. ScalarMul(list);
  130. FAIL();
  131. } catch (std::runtime_error const &err) {
  132. ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
  133. }
  134. list.clear();
  135. list.push_back(MakeValue(static_cast<int64_t>(2)));
  136. list.push_back(MakeValue(INT64_MIN));
  137. try {
  138. ScalarMul(list);
  139. FAIL();
  140. } catch (std::runtime_error const &err) {
  141. ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
  142. }
  143. list.clear();
  144. list.push_back(MakeValue(static_cast<int64_t>(0)));
  145. list.push_back(MakeValue(INT64_MIN));
  146. ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0);
  147. list.clear();
  148. }
  149. TEST_F(TestImplementations, ScalarDivTest) {
  150. ValuePtrList list;
  151. list.push_back(MakeValue(static_cast<int64_t>(6)));
  152. list.push_back(MakeValue(static_cast<int64_t>(3)));
  153. ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 2);
  154. list.clear();
  155. list.push_back(MakeValue(3.0f));
  156. list.push_back(MakeValue(1.5f));
  157. ASSERT_EQ(ScalarDiv(list)->cast<FP32ImmPtr>()->value(), 2.0f);
  158. list.clear();
  159. list.push_back(MakeValue(-4.0));
  160. list.push_back(MakeValue(2.0));
  161. ASSERT_EQ(ScalarDiv(list)->cast<FP64ImmPtr>()->value(), -2.0);
  162. list.clear();
  163. list.push_back(MakeValue(INT64_MAX));
  164. list.push_back(MakeValue(static_cast<int64_t>(0)));
  165. try {
  166. ScalarDiv(list);
  167. FAIL();
  168. } catch (std::runtime_error const &err) {
  169. ASSERT_TRUE(std::string(err.what()).find("Divisor could not be zero") != std::string::npos);
  170. }
  171. list.clear();
  172. list.push_back(MakeValue(INT64_MIN));
  173. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  174. try {
  175. ScalarDiv(list);
  176. FAIL();
  177. } catch (std::runtime_error const &err) {
  178. ASSERT_TRUE(std::string(err.what()).find("Overflow of the div of two signed number") != std::string::npos);
  179. }
  180. list.clear();
  181. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  182. list.push_back(MakeValue(INT64_MIN));
  183. ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0);
  184. list.clear();
  185. }
  186. TEST_F(TestImplementations, ScalarModTest) {
  187. ValuePtrList list;
  188. list.push_back(MakeValue(static_cast<int64_t>(7)));
  189. list.push_back(MakeValue(static_cast<int64_t>(3)));
  190. ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), 1);
  191. list.clear();
  192. list.push_back(MakeValue(static_cast<int64_t>(-8)));
  193. list.push_back(MakeValue(static_cast<int64_t>(3)));
  194. ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -2);
  195. list.clear();
  196. list.push_back(MakeValue(static_cast<int64_t>(-9)));
  197. list.push_back(MakeValue(static_cast<int64_t>(2)));
  198. ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -1);
  199. list.clear();
  200. list.push_back(MakeValue(INT64_MIN));
  201. list.push_back(MakeValue(static_cast<int64_t>(0)));
  202. try {
  203. ScalarMod(list);
  204. FAIL();
  205. } catch (std::runtime_error const &err) {
  206. ASSERT_TRUE(std::string(err.what()).find("Could not mod to zero") != std::string::npos);
  207. }
  208. list.clear();
  209. list.push_back(MakeValue(INT64_MIN));
  210. list.push_back(MakeValue(static_cast<int64_t>(-1)));
  211. try {
  212. ScalarMod(list);
  213. FAIL();
  214. } catch (std::runtime_error const &err) {
  215. ASSERT_TRUE(std::string(err.what()).find("Overflow of the mod of two signed number") != std::string::npos);
  216. }
  217. list.clear();
  218. }
  219. TEST_F(TestImplementations, ScalarUAddTest) {
  220. ValuePtrList list;
  221. list.push_back(MakeValue((uint64_t)1));
  222. ASSERT_EQ(ScalarUAdd(list)->cast<UInt64ImmPtr>()->value(), 1);
  223. list.clear();
  224. }
  225. TEST_F(TestImplementations, ScalarLogTest) {
  226. ValuePtrList list;
  227. list.push_back(MakeValue(static_cast<double>(7.3890560989306495)));
  228. ASSERT_EQ(ScalarLog(list)->cast<FP64ImmPtr>()->value(), 2.0);
  229. list.clear();
  230. }
  231. TEST_F(TestImplementations, ScalarUSubTest) {
  232. ValuePtrList list;
  233. list.push_back(MakeValue(static_cast<int64_t>(1)));
  234. ASSERT_EQ(ScalarUSub(list)->cast<Int64ImmPtr>()->value(), -1);
  235. list.clear();
  236. }
  237. TEST_F(TestImplementations, ScalarEqTest) {
  238. ValuePtrList list;
  239. list.push_back(MakeValue(1.0f));
  240. list.push_back(MakeValue(1.0f));
  241. ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
  242. list.clear();
  243. list.push_back(MakeValue(1.0f));
  244. list.push_back(MakeValue(-1.0f));
  245. ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), false);
  246. list.clear();
  247. list.push_back(MakeValue(1.0f));
  248. list.push_back(MakeValue(1.0));
  249. ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
  250. list.clear();
  251. list.push_back(MakeValue(1.0));
  252. list.push_back(MakeValue(1.0));
  253. ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
  254. list.clear();
  255. }
  256. TEST_F(TestImplementations, ScalarLtTest) {
  257. ValuePtrList list;
  258. list.push_back(MakeValue(1.0f));
  259. list.push_back(MakeValue(1.0f));
  260. ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false);
  261. list.clear();
  262. list.push_back(MakeValue(1.0f));
  263. list.push_back(MakeValue(-1.0f));
  264. ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false);
  265. list.clear();
  266. list.push_back(MakeValue(1.0f));
  267. list.push_back(MakeValue(2.5));
  268. ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true);
  269. list.clear();
  270. list.push_back(MakeValue(2.5));
  271. list.push_back(MakeValue(3.0));
  272. ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true);
  273. list.clear();
  274. }
  275. TEST_F(TestImplementations, ScalarGtTest) {
  276. ValuePtrList list;
  277. list.push_back(MakeValue(1.0f));
  278. list.push_back(MakeValue(2.0f));
  279. ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false);
  280. list.clear();
  281. list.push_back(MakeValue(2.0f));
  282. list.push_back(MakeValue(-1.0f));
  283. ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true);
  284. list.clear();
  285. list.push_back(MakeValue(2.0f));
  286. list.push_back(MakeValue(2.0));
  287. ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false);
  288. list.clear();
  289. list.push_back(MakeValue(2.5));
  290. list.push_back(MakeValue(2.0));
  291. ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true);
  292. list.clear();
  293. }
  294. TEST_F(TestImplementations, ScalarNeTest) {
  295. ValuePtrList list;
  296. list.push_back(MakeValue(1.0f));
  297. list.push_back(MakeValue(1.0f));
  298. ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false);
  299. list.clear();
  300. list.push_back(MakeValue(1.0f));
  301. list.push_back(MakeValue(-1.0f));
  302. ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true);
  303. list.clear();
  304. list.push_back(MakeValue(1.0f));
  305. list.push_back(MakeValue(2.0));
  306. ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true);
  307. list.clear();
  308. list.push_back(MakeValue(2.0));
  309. list.push_back(MakeValue(2.0));
  310. ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false);
  311. list.clear();
  312. }
  313. TEST_F(TestImplementations, ScalarLeTest) {
  314. ValuePtrList list;
  315. list.push_back(MakeValue(1.0f));
  316. list.push_back(MakeValue(1.0f));
  317. ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true);
  318. list.clear();
  319. list.push_back(MakeValue(1.0f));
  320. list.push_back(MakeValue(-1.0f));
  321. ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false);
  322. list.clear();
  323. list.push_back(MakeValue(1.0f));
  324. list.push_back(MakeValue(2.0));
  325. ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true);
  326. list.clear();
  327. list.push_back(MakeValue(6.0));
  328. list.push_back(MakeValue(-1.0f));
  329. ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false);
  330. list.clear();
  331. }
  332. TEST_F(TestImplementations, ScalarGeTest) {
  333. ValuePtrList list;
  334. list.push_back(MakeValue(1.0f));
  335. list.push_back(MakeValue(1.0f));
  336. ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
  337. list.clear();
  338. list.push_back(MakeValue(1.0f));
  339. list.push_back(MakeValue(-1.0f));
  340. ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
  341. list.clear();
  342. list.push_back(MakeValue(1.0f));
  343. list.push_back(MakeValue(2.0));
  344. ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), false);
  345. list.clear();
  346. list.push_back(MakeValue(6.0));
  347. list.push_back(MakeValue(-1.0f));
  348. ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
  349. list.clear();
  350. }
  351. TEST_F(TestImplementations, BoolNotTest) {
  352. ValuePtrList list;
  353. list.push_back(MakeValue(true));
  354. ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), false);
  355. list.clear();
  356. list.push_back(MakeValue(false));
  357. ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), true);
  358. list.clear();
  359. }
  360. TEST_F(TestImplementations, BoolAndTest) {
  361. ValuePtrList list;
  362. list.push_back(MakeValue(true));
  363. list.push_back(MakeValue(false));
  364. ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false);
  365. list.clear();
  366. list.push_back(MakeValue(true));
  367. list.push_back(MakeValue(true));
  368. ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), true);
  369. list.clear();
  370. list.push_back(MakeValue(false));
  371. list.push_back(MakeValue(false));
  372. ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false);
  373. list.clear();
  374. }
  375. TEST_F(TestImplementations, BoolOrTest) {
  376. ValuePtrList list;
  377. list.push_back(MakeValue(true));
  378. list.push_back(MakeValue(false));
  379. ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true);
  380. list.clear();
  381. list.push_back(MakeValue(true));
  382. list.push_back(MakeValue(true));
  383. ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true);
  384. list.clear();
  385. list.push_back(MakeValue(false));
  386. list.push_back(MakeValue(false));
  387. ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), false);
  388. list.clear();
  389. }
  390. TEST_F(TestImplementations, BoolEqTest) {
  391. ValuePtrList list;
  392. list.push_back(MakeValue(true));
  393. list.push_back(MakeValue(false));
  394. ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), false);
  395. list.clear();
  396. list.push_back(MakeValue(true));
  397. list.push_back(MakeValue(true));
  398. ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true);
  399. list.clear();
  400. list.push_back(MakeValue(false));
  401. list.push_back(MakeValue(false));
  402. ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true);
  403. list.clear();
  404. }
  405. } // namespace prim
  406. } // namespace mindspore