You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops_test.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include <vector>
  19. #include "common/common_test.h"
  20. #include "ir/value.h"
  21. #include "pybind_api/ir/primitive_py.h"
  22. #include "pipeline/jit/parse/python_adapter.h"
  23. #include "frontend/operator/ops.h"
  24. namespace mindspore {
  25. namespace prim {
  26. class TestOps : public UT::Common {
  27. public:
  28. TestOps() {}
  29. virtual void SetUp() {}
  30. };
  31. // Arithmetic
  32. TEST_F(TestOps, ScalarAddTest) {
  33. auto prim = std::make_shared<Primitive>("scalar_add");
  34. ASSERT_EQ(prim->name(), kPrimScalarAdd->name());
  35. }
  36. TEST_F(TestOps, ScalarSubTest) {
  37. auto prim = std::make_shared<Primitive>("scalar_sub");
  38. ASSERT_EQ(prim->name(), kPrimScalarSub->name());
  39. }
  40. TEST_F(TestOps, ScalarMulTest) {
  41. auto prim = std::make_shared<Primitive>("scalar_mul");
  42. ASSERT_EQ(prim->name(), kPrimScalarMul->name());
  43. }
  44. TEST_F(TestOps, ScalarDivTest) {
  45. auto prim = std::make_shared<Primitive>("scalar_div");
  46. ASSERT_EQ(prim->name(), kPrimScalarDiv->name());
  47. }
  48. TEST_F(TestOps, ScalarModTest) {
  49. auto prim = std::make_shared<Primitive>("scalar_mod");
  50. ASSERT_EQ(prim->name(), kPrimScalarMod->name());
  51. }
  52. TEST_F(TestOps, ScalarPowTest) {
  53. auto prim = std::make_shared<Primitive>("scalar_pow");
  54. ASSERT_EQ(prim->name(), kPrimScalarPow->name());
  55. }
  56. TEST_F(TestOps, ScalarTruncTest) {
  57. auto prim = std::make_shared<Primitive>("scalar_trunc");
  58. ASSERT_EQ(prim->name(), kPrimScalarTrunc->name());
  59. }
  60. TEST_F(TestOps, ScalarFloorTest) {
  61. auto prim = std::make_shared<Primitive>("scalar_floor");
  62. ASSERT_EQ(prim->name(), kPrimScalarFloor->name());
  63. }
  64. TEST_F(TestOps, ScalarUaddTest) {
  65. auto prim = std::make_shared<Primitive>("scalar_uadd");
  66. ASSERT_EQ(prim->name(), kPrimScalarUadd->name());
  67. }
  68. TEST_F(TestOps, ScalarUsubTest) {
  69. auto prim = std::make_shared<Primitive>("scalar_usub");
  70. ASSERT_EQ(prim->name(), kPrimScalarUsub->name());
  71. }
  72. TEST_F(TestOps, ScalarExpTest) {
  73. auto prim = std::make_shared<Primitive>("scalar_exp");
  74. ASSERT_EQ(prim->name(), kPrimScalarExp->name());
  75. }
  76. TEST_F(TestOps, ScalarLogTest) {
  77. auto prim = std::make_shared<Primitive>("scalar_log");
  78. ASSERT_EQ(prim->name(), kPrimScalarLog->name());
  79. }
  80. TEST_F(TestOps, ScalarSinTest) {
  81. auto prim = std::make_shared<Primitive>("scalar_sin");
  82. ASSERT_EQ(prim->name(), kPrimScalarSin->name());
  83. }
  84. TEST_F(TestOps, ScalarCosTest) {
  85. auto prim = std::make_shared<Primitive>("scalar_cos");
  86. ASSERT_EQ(prim->name(), kPrimScalarCos->name());
  87. }
  88. TEST_F(TestOps, ScalarTanTest) {
  89. auto prim = std::make_shared<Primitive>("scalar_tan");
  90. ASSERT_EQ(prim->name(), kPrimScalarTan->name());
  91. }
  92. // Comparisons
  93. TEST_F(TestOps, ScalarEqTest) {
  94. auto prim = std::make_shared<Primitive>("scalar_eq");
  95. ASSERT_EQ(prim->name(), kPrimScalarEq->name());
  96. }
  97. TEST_F(TestOps, ScalarLtTest) {
  98. auto prim = std::make_shared<Primitive>("scalar_lt");
  99. ASSERT_EQ(prim->name(), kPrimScalarLt->name());
  100. }
  101. TEST_F(TestOps, ScalarGtTest) {
  102. auto prim = std::make_shared<Primitive>("scalar_gt");
  103. ASSERT_EQ(prim->name(), kPrimScalarGt->name());
  104. }
  105. TEST_F(TestOps, ScalarNeTest) {
  106. auto prim = std::make_shared<Primitive>("scalar_ne");
  107. ASSERT_EQ(prim->name(), kPrimScalarNe->name());
  108. }
  109. TEST_F(TestOps, ScalarLeTest) {
  110. auto prim = std::make_shared<Primitive>("scalar_le");
  111. ASSERT_EQ(prim->name(), kPrimScalarLe->name());
  112. }
  113. TEST_F(TestOps, ScalarGeTest) {
  114. auto prim = std::make_shared<Primitive>("scalar_ge");
  115. ASSERT_EQ(prim->name(), kPrimScalarGe->name());
  116. }
  117. TEST_F(TestOps, BoolNotTest) {
  118. auto prim = std::make_shared<Primitive>("bool_not");
  119. ASSERT_EQ(prim->name(), kPrimBoolNot->name());
  120. }
  121. TEST_F(TestOps, BoolAndTest) {
  122. auto prim = std::make_shared<Primitive>("bool_and");
  123. ASSERT_EQ(prim->name(), kPrimBoolAnd->name());
  124. }
  125. TEST_F(TestOps, BoolOrTest) {
  126. auto prim = std::make_shared<Primitive>("bool_or");
  127. ASSERT_EQ(prim->name(), kPrimBoolOr->name());
  128. }
  129. TEST_F(TestOps, BoolEqTest) {
  130. auto prim = std::make_shared<Primitive>("bool_eq");
  131. ASSERT_EQ(prim->name(), kPrimBoolEq->name());
  132. }
  133. // Type introspection
  134. TEST_F(TestOps, TypeOfTest) {
  135. auto prim = std::make_shared<Primitive>("typeof");
  136. ASSERT_EQ(prim->name(), kPrimTypeOf->name());
  137. }
  138. TEST_F(TestOps, HasTypeTest) {
  139. auto prim = std::make_shared<Primitive>("hastype");
  140. ASSERT_EQ(prim->name(), kPrimHasType->name());
  141. }
  142. // Data structures
  143. TEST_F(TestOps, MakeTupleTest) {
  144. auto prim = std::make_shared<Primitive>("make_tuple");
  145. ASSERT_EQ(prim->name(), kPrimMakeTuple->name());
  146. }
  147. TEST_F(TestOps, MakeListTest) {
  148. auto prim = std::make_shared<Primitive>("make_list");
  149. ASSERT_EQ(prim->name(), kPrimMakeList->name());
  150. }
  151. TEST_F(TestOps, MakeRecordTest) {
  152. auto prim = std::make_shared<Primitive>("make_record");
  153. ASSERT_EQ(prim->name(), kPrimMakeRecord->name());
  154. }
  155. TEST_F(TestOps, TupleGetItemTest) {
  156. auto prim = std::make_shared<Primitive>("tuple_getitem");
  157. ASSERT_EQ(prim->name(), kPrimTupleGetItem->name());
  158. }
  159. TEST_F(TestOps, ListGetItemTest) {
  160. auto prim = std::make_shared<Primitive>("list_getitem");
  161. ASSERT_EQ(prim->name(), kPrimListGetItem->name());
  162. }
  163. TEST_F(TestOps, ArrayGetItemTest) {
  164. auto prim = std::make_shared<Primitive>("array_getitem");
  165. ASSERT_EQ(prim->name(), kPrimArrayGetItem->name());
  166. }
  167. TEST_F(TestOps, TupleSetItemTest) {
  168. auto prim = std::make_shared<Primitive>("tuple_setitem");
  169. ASSERT_EQ(prim->name(), kPrimTupleSetItem->name());
  170. }
  171. TEST_F(TestOps, ListSetItemTest) {
  172. auto prim = std::make_shared<Primitive>("list_setitem");
  173. ASSERT_EQ(prim->name(), kPrimListSetItem->name());
  174. }
  175. TEST_F(TestOps, ArraySetItemTest) {
  176. auto prim = std::make_shared<Primitive>("array_setitem");
  177. ASSERT_EQ(prim->name(), kPrimArraySetItem->name());
  178. }
  179. TEST_F(TestOps, ListAppendTest) {
  180. auto prim = std::make_shared<Primitive>("list_append");
  181. ASSERT_EQ(prim->name(), kPrimListAppend->name());
  182. }
  183. TEST_F(TestOps, GetAttrTest) {
  184. auto prim = std::make_shared<Primitive>("getattr");
  185. ASSERT_EQ(prim->name(), kPrimGetAttr->name());
  186. }
  187. TEST_F(TestOps, TupleLenTest) {
  188. auto prim = std::make_shared<Primitive>("tuple_len");
  189. ASSERT_EQ(prim->name(), kPrimTupleLen->name());
  190. }
  191. TEST_F(TestOps, ListLenTest) {
  192. auto prim = std::make_shared<Primitive>("list_len");
  193. ASSERT_EQ(prim->name(), kPrimListLen->name());
  194. }
  195. TEST_F(TestOps, ArrayLenTest) {
  196. auto prim = std::make_shared<Primitive>("array_len");
  197. ASSERT_EQ(prim->name(), kPrimArrayLen->name());
  198. }
  199. TEST_F(TestOps, ListMapTest) {
  200. auto prim = std::make_shared<Primitive>("list_map");
  201. ASSERT_EQ(prim->name(), kPrimListMap->name());
  202. }
  203. TEST_F(TestOps, ListReduceTest) {
  204. auto prim = std::make_shared<Primitive>("list_reduce");
  205. ASSERT_EQ(prim->name(), kPrimListReduce->name());
  206. }
  207. // Arrays
  208. TEST_F(TestOps, ScalarToArrayTest) {
  209. auto prim = std::make_shared<Primitive>("scalar_to_array");
  210. ASSERT_EQ(prim->name(), kPrimScalarToArray->name());
  211. }
  212. TEST_F(TestOps, ArrayToScalarTest) {
  213. auto prim = std::make_shared<Primitive>("array_to_scalar");
  214. ASSERT_EQ(prim->name(), kPrimArrayToScalar->name());
  215. }
  216. TEST_F(TestOps, BroadCastShapeTest) {
  217. auto prim = std::make_shared<Primitive>("broadcast_shape");
  218. ASSERT_EQ(prim->name(), kPrimBroadcastShape->name());
  219. }
  220. TEST_F(TestOps, ArrayMapTest) {
  221. auto prim = std::make_shared<Primitive>("array_map");
  222. ASSERT_EQ(prim->name(), kPrimArrayMap->name());
  223. }
  224. TEST_F(TestOps, ArrayReduceTest) {
  225. auto prim = std::make_shared<Primitive>("array_reduce");
  226. ASSERT_EQ(prim->name(), kPrimArrayReduce->name());
  227. }
  228. TEST_F(TestOps, DistributeTest) {
  229. auto prim = std::make_shared<Primitive>("distribute");
  230. ASSERT_EQ(prim->name(), kPrimDistribute->name());
  231. }
  232. TEST_F(TestOps, TransposeTest) {
  233. auto prim = std::make_shared<Primitive>("Transpose");
  234. ASSERT_EQ(prim->name(), kPrimTranspose->name());
  235. }
  236. TEST_F(TestOps, DotTest) {
  237. auto prim = std::make_shared<Primitive>("dot");
  238. ASSERT_EQ(prim->name(), kPrimDot->name());
  239. }
  240. TEST_F(TestOps, Im2ColTest) {
  241. auto prim = std::make_shared<Primitive>("im2col");
  242. ASSERT_EQ(prim->name(), kPrimIm2Col->name());
  243. }
  244. TEST_F(TestOps, Col2ImTest) {
  245. auto prim = std::make_shared<Primitive>("col2im");
  246. ASSERT_EQ(prim->name(), kPrimCol2Im->name());
  247. }
  248. TEST_F(TestOps, Im2ColV1Test) {
  249. auto prim = std::make_shared<Primitive>("im2col_v1");
  250. ASSERT_EQ(prim->name(), kPrimIm2ColV1->name());
  251. }
  252. TEST_F(TestOps, Col2ImV1Test) {
  253. auto prim = std::make_shared<Primitive>("col2im_v1");
  254. ASSERT_EQ(prim->name(), kPrimCol2ImV1->name());
  255. }
  256. // Statements
  257. TEST_F(TestOps, SwitchTest) {
  258. auto prim = std::make_shared<Primitive>("switch");
  259. ASSERT_EQ(prim->name(), kPrimSwitch->name());
  260. }
  261. TEST_F(TestOps, ReturnTest) {
  262. auto prim = std::make_shared<Primitive>("return");
  263. ASSERT_EQ(prim->name(), kPrimReturn->name());
  264. }
  265. // Miscellaneous
  266. TEST_F(TestOps, IdentityTest) {
  267. auto prim = std::make_shared<Primitive>("identity");
  268. ASSERT_EQ(prim->name(), kPrimIdentity->name());
  269. }
  270. TEST_F(TestOps, ResolveTest) {
  271. auto prim = std::make_shared<Primitive>("resolve");
  272. ASSERT_EQ(prim->name(), kPrimResolve->name());
  273. }
  274. TEST_F(TestOps, PartialTest) {
  275. auto prim = std::make_shared<Primitive>("Partial");
  276. ASSERT_EQ(prim->name(), kPrimPartial->name());
  277. }
  278. TEST_F(TestOps, JTest) {
  279. auto prim = std::make_shared<Primitive>("J");
  280. ASSERT_EQ(prim->name(), kPrimJ->name());
  281. }
  282. TEST_F(TestOps, EmbedTest) {
  283. auto prim = std::make_shared<Primitive>("embed");
  284. ASSERT_EQ(prim->name(), kPrimEmbed->name());
  285. }
  286. TEST_F(TestOps, EnvSetItemTest) {
  287. auto prim = std::make_shared<Primitive>("env_setitem");
  288. ASSERT_EQ(prim->name(), kPrimEnvSetItem->name());
  289. }
  290. TEST_F(TestOps, EnvGetItemTest) {
  291. auto prim = std::make_shared<Primitive>("env_getitem");
  292. ASSERT_EQ(prim->name(), kPrimEnvGetItem->name());
  293. }
  294. TEST_F(TestOps, EnvAddest) {
  295. auto prim = std::make_shared<Primitive>("env_add");
  296. ASSERT_EQ(prim->name(), kPrimEnvAdd->name());
  297. }
  298. // Neural Network
  299. TEST_F(TestOps, Conv2dTest) {
  300. auto prim = std::make_shared<Primitive>("Conv2D");
  301. ASSERT_EQ(prim->name(), kPrimConv2D->name());
  302. }
  303. TEST_F(TestOps, Conv2dAttrTest) {
  304. Primitive prim("Conv2D");
  305. prim.SetAttrs({
  306. {"stride", MakeValue(static_cast<int64_t>(3))},
  307. {"pad", MakeValue(static_cast<int64_t>(1))},
  308. });
  309. ASSERT_EQ(prim.name(), kPrimConv2D->name());
  310. Int64Imm stride(3);
  311. Int64Imm pad(1);
  312. ASSERT_EQ(*prim.GetAttr("stride"), stride);
  313. ASSERT_EQ(*prim.GetAttr("pad"), pad);
  314. }
  315. TEST_F(TestOps, CustomOpAttrTest) {
  316. Primitive prim("CustomOp", true, kPrimTypePyInferShape);
  317. prim.SetAttrs({
  318. {"attr1", MakeValue(static_cast<int64_t>(3))},
  319. {"attr2", MakeValue(static_cast<int64_t>(1))},
  320. });
  321. ASSERT_EQ(prim.name(), std::string("CustomOp"));
  322. ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
  323. auto attrs = prim.attrs();
  324. for (auto attr : attrs) {
  325. std::string prim_name = attr.first;
  326. auto prim_value = attr.second;
  327. std::cout << prim_name << std::endl;
  328. std::cout << prim_value << std::endl;
  329. }
  330. }
  331. TEST_F(TestOps, Conv2dBackpropInputTest) {
  332. auto prim = std::make_shared<Primitive>("Conv2DBackpropInput");
  333. ASSERT_EQ(prim->name(), kPrimConv2DBackpropInput->name());
  334. }
  335. TEST_F(TestOps, Conv2dBackpropFilterTest) {
  336. auto prim = std::make_shared<Primitive>("Conv2DBackpropFilter");
  337. ASSERT_EQ(prim->name(), kPrimConv2DBackpropFilter->name());
  338. }
  339. TEST_F(TestOps, ReluTest) {
  340. auto prim = std::make_shared<Primitive>("ReLU");
  341. ASSERT_EQ(prim->name(), kPrimRelu->name());
  342. }
  343. TEST_F(TestOps, FusedBatchNormTest) {
  344. auto prim = std::make_shared<Primitive>("FusedBatchNorm");
  345. ASSERT_EQ(prim->name(), kPrimFusedBatchNorm->name());
  346. }
  347. TEST_F(TestOps, FusedBatchNormAttrTest) {
  348. Primitive prim("FusedBatchNorm");
  349. prim.SetAttrs({
  350. {"epsilon", MakeValue(0.001f)},
  351. {"momentum", MakeValue(0.1f)},
  352. });
  353. ASSERT_EQ(prim.name(), kPrimFusedBatchNorm->name());
  354. FP32Imm epsilon(0.001f);
  355. FP32Imm momentum(0.1f);
  356. ASSERT_EQ(*prim.GetAttr("epsilon"), epsilon);
  357. ASSERT_EQ(*prim.GetAttr("momentum"), momentum);
  358. }
  359. TEST_F(TestOps, PoolingTest) {
  360. auto prim = std::make_shared<Primitive>("Pooling");
  361. ASSERT_EQ(prim->name(), kPrimPooling->name());
  362. }
  363. TEST_F(TestOps, GetConv2DPrimPyTest) {
  364. auto conv2d_prim = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
  365. ASSERT_TRUE(conv2d_prim);
  366. PrimitivePyPtr conv2d_ptr = dyn_cast<PrimitivePy>(conv2d_prim);
  367. ASSERT_TRUE(conv2d_ptr);
  368. if (nullptr != conv2d_ptr) {
  369. MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name();
  370. if(!conv2d_ptr->HasComputeFunction()){
  371. MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented";
  372. }
  373. py::object conv2d_pyobj = parse::python_adapter::GetPyFn("gtest_input.pynative", "conv2d_prim");
  374. py::dict opAttrs = py::getattr(conv2d_pyobj, "attrs");
  375. std::unordered_map<std::string, ValuePtr> attrs{};
  376. for (auto item : opAttrs) {
  377. if (!py::isinstance<py::str>(item.first)) {
  378. MS_LOG(EXCEPTION) << "type error in py dict convert";
  379. }
  380. std::string name = py::cast<std::string>(item.first);
  381. MS_LOG(INFO) << "Attr name: " << name;
  382. ValuePtr converted_ret;
  383. parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
  384. MS_LOG(INFO) << "Attr value: " << converted_ret->ToString();
  385. attrs.emplace(name, converted_ret);
  386. }
  387. }
  388. MS_LOG(INFO) << "Finish GetPyFnTest!";
  389. }
  390. } // namespace prim
  391. } // namespace mindspore