You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops_test.cc 13 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include <vector>
  19. #include "common/common_test.h"
  20. #include "ir/value.h"
  21. #include "pybind_api/ir/primitive_py.h"
  22. #include "pipeline/jit/parse/python_adapter.h"
  23. #include "frontend/operator/ops.h"
  24. #include "base/core_ops.h"
  25. namespace mindspore {
  26. namespace prim {
  27. class TestOps : public UT::Common {
  28. public:
  29. TestOps() {}
  30. virtual void SetUp() {}
  31. };
  32. // Arithmetic
  33. TEST_F(TestOps, ScalarAddTest) {
  34. auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
  35. ASSERT_EQ(prim->name(), kPrimScalarAdd->name());
  36. }
  37. TEST_F(TestOps, ScalarSubTest) {
  38. auto prim = std::make_shared<Primitive>(prim::kScalarSub);
  39. ASSERT_EQ(prim->name(), kPrimScalarSub->name());
  40. }
  41. TEST_F(TestOps, ScalarMulTest) {
  42. auto prim = std::make_shared<Primitive>(prim::kScalarMul);
  43. ASSERT_EQ(prim->name(), kPrimScalarMul->name());
  44. }
  45. TEST_F(TestOps, ScalarDivTest) {
  46. auto prim = std::make_shared<Primitive>(prim::kScalarDiv);
  47. ASSERT_EQ(prim->name(), kPrimScalarDiv->name());
  48. }
  49. TEST_F(TestOps, ScalarModTest) {
  50. auto prim = std::make_shared<Primitive>(prim::kScalarMod);
  51. ASSERT_EQ(prim->name(), kPrimScalarMod->name());
  52. }
  53. TEST_F(TestOps, ScalarPowTest) {
  54. auto prim = std::make_shared<Primitive>(prim::kScalarPow);
  55. ASSERT_EQ(prim->name(), kPrimScalarPow->name());
  56. }
  57. TEST_F(TestOps, ScalarTruncTest) {
  58. auto prim = std::make_shared<Primitive>(prim::kScalarTrunc);
  59. ASSERT_EQ(prim->name(), kPrimScalarTrunc->name());
  60. }
  61. TEST_F(TestOps, ScalarFloorTest) {
  62. auto prim = std::make_shared<Primitive>(prim::kScalarFloor);
  63. ASSERT_EQ(prim->name(), kPrimScalarFloor->name());
  64. }
  65. TEST_F(TestOps, ScalarUaddTest) {
  66. auto prim = std::make_shared<Primitive>(prim::kScalarUadd);
  67. ASSERT_EQ(prim->name(), kPrimScalarUadd->name());
  68. }
  69. TEST_F(TestOps, ScalarUsubTest) {
  70. auto prim = std::make_shared<Primitive>(prim::kScalarUsub);
  71. ASSERT_EQ(prim->name(), kPrimScalarUsub->name());
  72. }
  73. TEST_F(TestOps, ScalarExpTest) {
  74. auto prim = std::make_shared<Primitive>("scalar_exp");
  75. ASSERT_EQ(prim->name(), kPrimScalarExp->name());
  76. }
  77. TEST_F(TestOps, ScalarLogTest) {
  78. auto prim = std::make_shared<Primitive>("scalar_log");
  79. ASSERT_EQ(prim->name(), kPrimScalarLog->name());
  80. }
  81. TEST_F(TestOps, ScalarSinTest) {
  82. auto prim = std::make_shared<Primitive>("scalar_sin");
  83. ASSERT_EQ(prim->name(), kPrimScalarSin->name());
  84. }
  85. TEST_F(TestOps, ScalarCosTest) {
  86. auto prim = std::make_shared<Primitive>("scalar_cos");
  87. ASSERT_EQ(prim->name(), kPrimScalarCos->name());
  88. }
  89. TEST_F(TestOps, ScalarTanTest) {
  90. auto prim = std::make_shared<Primitive>("scalar_tan");
  91. ASSERT_EQ(prim->name(), kPrimScalarTan->name());
  92. }
  93. // Comparisons
  94. TEST_F(TestOps, ScalarEqTest) {
  95. auto prim = std::make_shared<Primitive>("scalar_eq");
  96. ASSERT_EQ(prim->name(), kPrimScalarEq->name());
  97. }
  98. TEST_F(TestOps, ScalarLtTest) {
  99. auto prim = std::make_shared<Primitive>("scalar_lt");
  100. ASSERT_EQ(prim->name(), kPrimScalarLt->name());
  101. }
  102. TEST_F(TestOps, ScalarGtTest) {
  103. auto prim = std::make_shared<Primitive>("scalar_gt");
  104. ASSERT_EQ(prim->name(), kPrimScalarGt->name());
  105. }
  106. TEST_F(TestOps, ScalarNeTest) {
  107. auto prim = std::make_shared<Primitive>("scalar_ne");
  108. ASSERT_EQ(prim->name(), kPrimScalarNe->name());
  109. }
  110. TEST_F(TestOps, ScalarLeTest) {
  111. auto prim = std::make_shared<Primitive>("scalar_le");
  112. ASSERT_EQ(prim->name(), kPrimScalarLe->name());
  113. }
  114. TEST_F(TestOps, ScalarGeTest) {
  115. auto prim = std::make_shared<Primitive>("scalar_ge");
  116. ASSERT_EQ(prim->name(), kPrimScalarGe->name());
  117. }
  118. TEST_F(TestOps, BoolNotTest) {
  119. auto prim = std::make_shared<Primitive>("bool_not");
  120. ASSERT_EQ(prim->name(), kPrimBoolNot->name());
  121. }
  122. TEST_F(TestOps, BoolAndTest) {
  123. auto prim = std::make_shared<Primitive>("bool_and");
  124. ASSERT_EQ(prim->name(), kPrimBoolAnd->name());
  125. }
  126. TEST_F(TestOps, BoolOrTest) {
  127. auto prim = std::make_shared<Primitive>("bool_or");
  128. ASSERT_EQ(prim->name(), kPrimBoolOr->name());
  129. }
  130. TEST_F(TestOps, BoolEqTest) {
  131. auto prim = std::make_shared<Primitive>("bool_eq");
  132. ASSERT_EQ(prim->name(), kPrimBoolEq->name());
  133. }
  134. // Type introspection
  135. TEST_F(TestOps, TypeOfTest) {
  136. auto prim = std::make_shared<Primitive>("typeof");
  137. ASSERT_EQ(prim->name(), kPrimTypeOf->name());
  138. }
  139. TEST_F(TestOps, HasTypeTest) {
  140. auto prim = std::make_shared<Primitive>("hastype");
  141. ASSERT_EQ(prim->name(), kPrimHasType->name());
  142. }
  143. // Data structures
  144. TEST_F(TestOps, MakeTupleTest) {
  145. auto prim = std::make_shared<Primitive>("MakeTuple");
  146. ASSERT_EQ(prim->name(), kPrimMakeTuple->name());
  147. }
  148. TEST_F(TestOps, MakeListTest) {
  149. auto prim = std::make_shared<Primitive>("make_list");
  150. ASSERT_EQ(prim->name(), kPrimMakeList->name());
  151. }
  152. TEST_F(TestOps, MakeRecordTest) {
  153. auto prim = std::make_shared<Primitive>("make_record");
  154. ASSERT_EQ(prim->name(), kPrimMakeRecord->name());
  155. }
  156. TEST_F(TestOps, TupleGetItemTest) {
  157. auto prim = std::make_shared<Primitive>(kTupleGetItem);
  158. ASSERT_EQ(prim->name(), kPrimTupleGetItem->name());
  159. }
  160. TEST_F(TestOps, ListGetItemTest) {
  161. auto prim = std::make_shared<Primitive>("list_getitem");
  162. ASSERT_EQ(prim->name(), kPrimListGetItem->name());
  163. }
  164. TEST_F(TestOps, ArrayGetItemTest) {
  165. auto prim = std::make_shared<Primitive>("array_getitem");
  166. ASSERT_EQ(prim->name(), kPrimArrayGetItem->name());
  167. }
  168. TEST_F(TestOps, TupleSetItemTest) {
  169. auto prim = std::make_shared<Primitive>("tuple_setitem");
  170. ASSERT_EQ(prim->name(), kPrimTupleSetItem->name());
  171. }
  172. TEST_F(TestOps, ListSetItemTest) {
  173. auto prim = std::make_shared<Primitive>("list_setitem");
  174. ASSERT_EQ(prim->name(), kPrimListSetItem->name());
  175. }
  176. TEST_F(TestOps, ArraySetItemTest) {
  177. auto prim = std::make_shared<Primitive>("array_setitem");
  178. ASSERT_EQ(prim->name(), kPrimArraySetItem->name());
  179. }
  180. TEST_F(TestOps, ListAppendTest) {
  181. auto prim = std::make_shared<Primitive>("list_append");
  182. ASSERT_EQ(prim->name(), kPrimListAppend->name());
  183. }
  184. TEST_F(TestOps, GetAttrTest) {
  185. auto prim = std::make_shared<Primitive>("getattr");
  186. ASSERT_EQ(prim->name(), kPrimGetAttr->name());
  187. }
  188. TEST_F(TestOps, TupleLenTest) {
  189. auto prim = std::make_shared<Primitive>("tuple_len");
  190. ASSERT_EQ(prim->name(), kPrimTupleLen->name());
  191. }
  192. TEST_F(TestOps, ListLenTest) {
  193. auto prim = std::make_shared<Primitive>("list_len");
  194. ASSERT_EQ(prim->name(), kPrimListLen->name());
  195. }
  196. TEST_F(TestOps, ArrayLenTest) {
  197. auto prim = std::make_shared<Primitive>("array_len");
  198. ASSERT_EQ(prim->name(), kPrimArrayLen->name());
  199. }
  200. TEST_F(TestOps, ListMapTest) {
  201. auto prim = std::make_shared<Primitive>("list_map");
  202. ASSERT_EQ(prim->name(), kPrimListMap->name());
  203. }
  204. TEST_F(TestOps, ListReduceTest) {
  205. auto prim = std::make_shared<Primitive>("list_reduce");
  206. ASSERT_EQ(prim->name(), kPrimListReduce->name());
  207. }
  208. // Arrays
  209. TEST_F(TestOps, ScalarToArrayTest) {
  210. auto prim = std::make_shared<Primitive>("scalar_to_array");
  211. ASSERT_EQ(prim->name(), kPrimScalarToArray->name());
  212. }
  213. TEST_F(TestOps, ArrayToScalarTest) {
  214. auto prim = std::make_shared<Primitive>("array_to_scalar");
  215. ASSERT_EQ(prim->name(), kPrimArrayToScalar->name());
  216. }
  217. TEST_F(TestOps, BroadCastShapeTest) {
  218. auto prim = std::make_shared<Primitive>("broadcast_shape");
  219. ASSERT_EQ(prim->name(), kPrimBroadcastShape->name());
  220. }
  221. TEST_F(TestOps, ArrayMapTest) {
  222. auto prim = std::make_shared<Primitive>("array_map");
  223. ASSERT_EQ(prim->name(), kPrimArrayMap->name());
  224. }
  225. TEST_F(TestOps, ArrayReduceTest) {
  226. auto prim = std::make_shared<Primitive>("array_reduce");
  227. ASSERT_EQ(prim->name(), kPrimArrayReduce->name());
  228. }
  229. TEST_F(TestOps, DistributeTest) {
  230. auto prim = std::make_shared<Primitive>("distribute");
  231. ASSERT_EQ(prim->name(), kPrimDistribute->name());
  232. }
  233. TEST_F(TestOps, TransposeTest) {
  234. auto prim = std::make_shared<Primitive>("Transpose");
  235. ASSERT_EQ(prim->name(), kPrimTranspose->name());
  236. }
  237. TEST_F(TestOps, Im2ColTest) {
  238. auto prim = std::make_shared<Primitive>("im2col");
  239. ASSERT_EQ(prim->name(), kPrimIm2Col->name());
  240. }
  241. TEST_F(TestOps, Col2ImTest) {
  242. auto prim = std::make_shared<Primitive>("col2im");
  243. ASSERT_EQ(prim->name(), kPrimCol2Im->name());
  244. }
  245. TEST_F(TestOps, Im2ColV1Test) {
  246. auto prim = std::make_shared<Primitive>("im2col_v1");
  247. ASSERT_EQ(prim->name(), kPrimIm2ColV1->name());
  248. }
  249. TEST_F(TestOps, Col2ImV1Test) {
  250. auto prim = std::make_shared<Primitive>("col2im_v1");
  251. ASSERT_EQ(prim->name(), kPrimCol2ImV1->name());
  252. }
  253. // Statements
  254. TEST_F(TestOps, SwitchTest) {
  255. auto prim = std::make_shared<Primitive>("Switch");
  256. ASSERT_EQ(prim->name(), kPrimSwitch->name());
  257. }
  258. TEST_F(TestOps, ReturnTest) {
  259. auto prim = std::make_shared<Primitive>("Return");
  260. ASSERT_EQ(prim->name(), kPrimReturn->name());
  261. }
  262. // Miscellaneous
  263. TEST_F(TestOps, IdentityTest) {
  264. auto prim = std::make_shared<Primitive>("identity");
  265. ASSERT_EQ(prim->name(), kPrimIdentity->name());
  266. }
  267. TEST_F(TestOps, ResolveTest) {
  268. auto prim = std::make_shared<Primitive>("resolve");
  269. ASSERT_EQ(prim->name(), kPrimResolve->name());
  270. }
  271. TEST_F(TestOps, PartialTest) {
  272. auto prim = std::make_shared<Primitive>("Partial");
  273. ASSERT_EQ(prim->name(), kPrimPartial->name());
  274. }
  275. TEST_F(TestOps, JTest) {
  276. auto prim = std::make_shared<Primitive>("J");
  277. ASSERT_EQ(prim->name(), kPrimJ->name());
  278. }
  279. TEST_F(TestOps, EmbedTest) {
  280. auto prim = std::make_shared<Primitive>("embed");
  281. ASSERT_EQ(prim->name(), kPrimEmbed->name());
  282. }
  283. TEST_F(TestOps, EnvSetItemTest) {
  284. auto prim = std::make_shared<Primitive>("env_setitem");
  285. ASSERT_EQ(prim->name(), kPrimEnvSetItem->name());
  286. }
  287. TEST_F(TestOps, EnvGetItemTest) {
  288. auto prim = std::make_shared<Primitive>("env_getitem");
  289. ASSERT_EQ(prim->name(), kPrimEnvGetItem->name());
  290. }
  291. TEST_F(TestOps, EnvAddest) {
  292. auto prim = std::make_shared<Primitive>("env_add");
  293. ASSERT_EQ(prim->name(), kPrimEnvAdd->name());
  294. }
  295. // Neural Network
  296. TEST_F(TestOps, Conv2dTest) {
  297. auto prim = std::make_shared<Primitive>("Conv2D");
  298. ASSERT_EQ(prim->name(), kPrimConv2D->name());
  299. }
  300. TEST_F(TestOps, Conv2dAttrTest) {
  301. Primitive prim("Conv2D");
  302. prim.SetAttrs({
  303. {"stride", MakeValue(static_cast<int64_t>(3))},
  304. {"pad", MakeValue(static_cast<int64_t>(1))},
  305. });
  306. ASSERT_EQ(prim.name(), kPrimConv2D->name());
  307. Int64Imm stride(3);
  308. Int64Imm pad(1);
  309. ASSERT_EQ(*prim.GetAttr("stride"), stride);
  310. ASSERT_EQ(*prim.GetAttr("pad"), pad);
  311. }
  312. TEST_F(TestOps, CustomOpAttrTest) {
  313. Primitive prim("CustomOp", true, kPrimTypePyInferShape);
  314. prim.SetAttrs({
  315. {"attr1", MakeValue(static_cast<int64_t>(3))},
  316. {"attr2", MakeValue(static_cast<int64_t>(1))},
  317. });
  318. ASSERT_EQ(prim.name(), std::string("CustomOp"));
  319. ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
  320. auto attrs = prim.attrs();
  321. for (auto attr : attrs) {
  322. std::string prim_name = attr.first;
  323. auto prim_value = attr.second;
  324. std::cout << prim_name << std::endl;
  325. std::cout << prim_value << std::endl;
  326. }
  327. }
  328. TEST_F(TestOps, Conv2dBackpropInputTest) {
  329. auto prim = std::make_shared<Primitive>("Conv2DBackpropInput");
  330. ASSERT_EQ(prim->name(), kPrimConv2DBackpropInput->name());
  331. }
  332. TEST_F(TestOps, Conv2dBackpropFilterTest) {
  333. auto prim = std::make_shared<Primitive>("Conv2DBackpropFilter");
  334. ASSERT_EQ(prim->name(), kPrimConv2DBackpropFilter->name());
  335. }
  336. TEST_F(TestOps, ReluTest) {
  337. auto prim = std::make_shared<Primitive>("ReLU");
  338. ASSERT_EQ(prim->name(), kPrimRelu->name());
  339. }
  340. TEST_F(TestOps, PoolingTest) {
  341. auto prim = std::make_shared<Primitive>("Pooling");
  342. ASSERT_EQ(prim->name(), kPrimPooling->name());
  343. }
  344. TEST_F(TestOps, GetConv2DPrimPyTest) {
  345. auto conv2d_prim = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
  346. ASSERT_TRUE(conv2d_prim);
  347. PrimitivePyPtr conv2d_ptr = dyn_cast<PrimitivePy>(conv2d_prim);
  348. ASSERT_TRUE(conv2d_ptr);
  349. if (nullptr != conv2d_ptr) {
  350. MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name();
  351. if(!conv2d_ptr->HasComputeFunction()){
  352. MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented";
  353. }
  354. py::object conv2d_pyobj = parse::python_adapter::GetPyFn("gtest_input.pynative", "conv2d_prim");
  355. py::dict opAttrs = py::getattr(conv2d_pyobj, "attrs");
  356. std::unordered_map<std::string, ValuePtr> attrs{};
  357. for (auto item : opAttrs) {
  358. if (!py::isinstance<py::str>(item.first)) {
  359. MS_LOG(EXCEPTION) << "type error in py dict convert";
  360. }
  361. std::string name = py::cast<std::string>(item.first);
  362. MS_LOG(INFO) << "Attr name: " << name;
  363. ValuePtr converted_ret;
  364. parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
  365. MS_LOG(INFO) << "Attr value: " << converted_ret->ToString();
  366. attrs.emplace(name, converted_ret);
  367. }
  368. }
  369. MS_LOG(INFO) << "Finish GetPyFnTest!";
  370. }
  371. } // namespace prim
  372. } // namespace mindspore