You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite_test.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "operator/composite/composite.h"
  21. #include "operator/ops.h"
  22. #include "pipeline/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/abstract_function.h"
  24. #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractTensor = abstract::AbstractTensor;
  34. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  35. using AbstractNone = abstract::AbstractNone;
  36. using AbstractAttribute = abstract::AbstractAttribute;
  37. using AnalysisEngine = abstract::AnalysisEngine;
  38. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  39. class TestComposite : public UT::Common {
  40. public:
  41. virtual void SetUp();
  42. virtual void TearDown();
  43. AnalysisEnginePtr engine_;
  44. };
  45. void TestComposite::SetUp() {
  46. // init resource
  47. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  48. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  49. }
  50. void TestComposite::TearDown() {
  51. // destroy resource
  52. }
  53. class UTCompositeUtils {
  54. public:
  55. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int> shp) {
  56. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  57. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  58. }
  59. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  60. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  63. for (size_t i = 0; i < nparam; i++) {
  64. inputs.push_back(func_graph->add_parameter());
  65. }
  66. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(prim::kPrimReturn));
  69. inputs.push_back(cnode_prim);
  70. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  71. func_graph->set_return(cnode_return);
  72. return func_graph;
  73. }
  74. };
  75. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  76. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  77. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  78. AbstractBasePtrList eles;
  79. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  80. size_t tuple_size = 6;
  81. for (size_t i = 0; i < tuple_size; i++) {
  82. eles.push_back(tensor);
  83. }
  84. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  85. auto start_index = std::make_shared<AbstractScalar>(1);
  86. auto stop_index = std::make_shared<AbstractScalar>(5);
  87. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  88. try {
  89. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  90. FAIL() << "Excepted exception :Args type is wrong";
  91. } catch (std::runtime_error const &err) {
  92. ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
  93. } catch (...) {
  94. FAIL() << "Excepted exception :Args type is wrong";
  95. }
  96. }
  97. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  98. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  99. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  100. AbstractBasePtrList eles;
  101. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  102. size_t tuple_size = 6;
  103. for (size_t i = 0; i < tuple_size; i++) {
  104. eles.push_back(tensor);
  105. }
  106. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  107. auto start_index = std::make_shared<AbstractScalar>(1);
  108. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  109. try {
  110. trace::ClearTraceStack();
  111. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  112. FAIL() << "Excepted exception :Args type is wrong";
  113. } catch (pybind11::type_error const &err) {
  114. ASSERT_TRUE(true);
  115. } catch (...) {
  116. FAIL() << "Excepted exception :Args type is wrong";
  117. }
  118. }
  119. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  120. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  121. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  122. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  123. AbstractBasePtrList eles;
  124. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  125. size_t tuple_size = 6;
  126. for (size_t i = 0; i < tuple_size; i++) {
  127. eles.push_back(tensor);
  128. }
  129. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  130. auto start_index = std::make_shared<AbstractScalar>(1);
  131. auto stop_index = std::make_shared<AbstractScalar>(6);
  132. auto step = std::make_shared<AbstractScalar>(2);
  133. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  134. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  135. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  136. if (ret == nullptr) {
  137. FAIL() << "Cast ret to abstract tuple failed.";
  138. }
  139. size_t real = ret->size();
  140. size_t expect = 3;
  141. ASSERT_EQ(real, expect);
  142. }
  143. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  144. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  145. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  146. AbstractBasePtrList eles;
  147. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  148. size_t tuple_size = 6;
  149. for (size_t i = 0; i < tuple_size; i++) {
  150. eles.push_back(tensor);
  151. }
  152. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  153. auto start_index = std::make_shared<AbstractScalar>(1);
  154. auto stop_index = std::make_shared<AbstractScalar>(5);
  155. auto step = std::make_shared<AbstractNone>();
  156. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  157. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  158. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  159. if (ret == nullptr) {
  160. FAIL() << "Cast ret to abstract tuple failed.";
  161. }
  162. size_t real = ret->size();
  163. size_t expect = 4;
  164. ASSERT_EQ(real, expect);
  165. }
  166. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  167. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  168. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  169. AbstractBasePtrList eles;
  170. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  171. size_t tuple_size = 6;
  172. for (size_t i = 0; i < tuple_size; i++) {
  173. eles.push_back(tensor);
  174. }
  175. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  176. auto start_index = std::make_shared<AbstractNone>();
  177. auto stop_index = std::make_shared<AbstractNone>();
  178. auto step = std::make_shared<AbstractScalar>(-1);
  179. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  180. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  181. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  182. if (ret == nullptr) {
  183. FAIL() << "Cast ret to abstract tuple failed.";
  184. }
  185. size_t real = ret->size();
  186. size_t expect = 6;
  187. ASSERT_EQ(real, expect);
  188. }
  189. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  190. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  191. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  192. AbstractBasePtrList eles;
  193. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  194. size_t tuple_size = 6;
  195. for (size_t i = 0; i < tuple_size; i++) {
  196. eles.push_back(tensor);
  197. }
  198. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  199. auto start_index = std::make_shared<AbstractScalar>(-2);
  200. auto stop_index = std::make_shared<AbstractNone>();
  201. auto step = std::make_shared<AbstractScalar>(-1);
  202. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  203. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  204. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  205. if (ret == nullptr) {
  206. FAIL() << "Cast ret to abstract tuple failed.";
  207. }
  208. size_t real = ret->size();
  209. size_t expect = 5;
  210. ASSERT_EQ(real, expect);
  211. }
  212. TEST_F(TestComposite, test_TensorSliceBySlice) {
  213. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  214. FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  215. AbstractBasePtrList eles;
  216. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  217. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  218. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  219. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  220. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  221. AbstractBasePtrList args_spec_list = {tensor, slice};
  222. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred);
  223. if (ret == nullptr) {
  224. FAIL() << "Cast ret to abstract array failed.";
  225. }
  226. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8});
  227. ASSERT_EQ(*ret, *expect);
  228. }
  229. TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
  230. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  231. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  232. AbstractBasePtrList eles;
  233. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(0);
  234. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  235. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  236. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  237. eles.push_back(slice);
  238. start_index = std::make_shared<AbstractScalar>(1);
  239. stop_index = std::make_shared<AbstractScalar>(5);
  240. step = std::make_shared<AbstractScalar>(1);
  241. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  242. eles.push_back(slice);
  243. start_index = std::make_shared<AbstractScalar>(2);
  244. stop_index = std::make_shared<AbstractScalar>(8);
  245. step = std::make_shared<AbstractScalar>(3);
  246. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  247. eles.push_back(slice);
  248. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  249. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  250. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  251. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  252. if (ret == nullptr) {
  253. FAIL() << "Cast ret to abstract array failed.";
  254. }
  255. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2});
  256. ASSERT_EQ(*ret, *expect);
  257. }
  258. TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
  259. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  260. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  261. AbstractBasePtrList eles;
  262. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  263. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(5);
  264. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  265. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  266. eles.push_back(slice);
  267. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  268. eles.push_back(elem_index);
  269. start_index = std::make_shared<AbstractScalar>(2);
  270. stop_index = std::make_shared<AbstractScalar>(6);
  271. step = std::make_shared<AbstractScalar>(1);
  272. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  273. eles.push_back(slice);
  274. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  275. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  276. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  277. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  278. if (ret == nullptr) {
  279. FAIL() << "Cast ret to abstract array failed.";
  280. }
  281. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4});
  282. ASSERT_EQ(*ret, *expect);
  283. }
  284. TEST_F(TestComposite, test_TensorSliceByScalar) {
  285. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  286. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  287. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  288. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
  289. AbstractBasePtrList args_spec_list = {tensor, start_index};
  290. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  291. if (ret == nullptr) {
  292. FAIL() << "Cast ret to abstract array failed.";
  293. }
  294. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8});
  295. ASSERT_EQ(*ret, *expect);
  296. }
  297. TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
  298. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  299. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  300. AbstractBasePtrList eles;
  301. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  302. eles.push_back(elem_index);
  303. elem_index = std::make_shared<AbstractScalar>(3);
  304. eles.push_back(elem_index);
  305. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  306. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  307. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  308. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  309. if (ret == nullptr) {
  310. FAIL() << "Cast ret to abstract array failed.";
  311. }
  312. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8});
  313. ASSERT_EQ(*ret, *expect);
  314. }
  315. TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
  316. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  317. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  318. AbstractBasePtrList eles;
  319. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(3);
  320. eles.push_back(elem_index);
  321. elem_index = std::make_shared<AbstractScalar>(0);
  322. eles.push_back(elem_index);
  323. elem_index = std::make_shared<AbstractScalar>(6);
  324. eles.push_back(elem_index);
  325. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  326. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  327. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  328. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  329. if (ret == nullptr) {
  330. FAIL() << "Cast ret to abstract array failed.";
  331. }
  332. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({});
  333. ASSERT_EQ(*ret, *expect);
  334. }
  335. TEST_F(TestComposite, test_UnpackCall_3args) {
  336. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  337. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  338. auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  339. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  340. AbstractBasePtrList eles;
  341. for (size_t i = 0; i < 6; i++) {
  342. eles.push_back(tensor);
  343. }
  344. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  345. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  346. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  347. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  348. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  349. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  350. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  351. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  352. if (ret == nullptr) {
  353. FAIL() << "Cast ret to abstract tuple failed.";
  354. }
  355. size_t real = ret->size();
  356. size_t expect = 9;
  357. ASSERT_EQ(real, expect);
  358. }
  359. TEST_F(TestComposite, test_UnpackCall_5args) {
  360. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  361. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  362. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  363. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  364. AbstractBasePtrList eles;
  365. for (size_t i = 0; i < 6; i++) {
  366. eles.push_back(tensor);
  367. }
  368. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  369. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  370. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  371. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  372. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  373. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  374. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  375. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  376. if (ret == nullptr) {
  377. FAIL() << "Cast ret to abstract tuple failed.";
  378. }
  379. size_t real = ret->size();
  380. size_t expect = 18;
  381. ASSERT_EQ(real, expect);
  382. }
  383. TEST_F(TestComposite, test_ZipOperation) {
  384. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  385. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  386. AbstractBasePtrList eles;
  387. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  388. size_t tuple_size = 3;
  389. for (size_t i = 0; i < tuple_size; i++) {
  390. eles.push_back(tensor);
  391. }
  392. auto tuple = std::make_shared<AbstractTuple>(eles);
  393. AbstractBasePtrList args_spec_list = {tuple};
  394. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred);
  395. if (ret == nullptr) {
  396. FAIL() << "Cast ret to abstract tuple failed.";
  397. }
  398. size_t real = ret->size();
  399. size_t expect = 3;
  400. ASSERT_EQ(real, expect);
  401. }
  402. } // namespace mindspore