You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

composite_test.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "operator/composite/composite.h"
  21. #include "operator/ops.h"
  22. #include "pipeline/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/abstract_function.h"
  24. namespace mindspore {
  25. using Shape = abstract::Shape;
  26. using AbstractScalar = abstract::AbstractScalar;
  27. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  28. using AbstractSlice = abstract::AbstractSlice;
  29. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  30. using AbstractTuple = abstract::AbstractTuple;
  31. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  32. using AbstractTensor = abstract::AbstractTensor;
  33. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  34. using AbstractNone = abstract::AbstractNone;
  35. using AbstractAttribute = abstract::AbstractAttribute;
  36. using AnalysisEngine = abstract::AnalysisEngine;
  37. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  38. class TestComposite : public UT::Common {
  39. public:
  40. virtual void SetUp();
  41. virtual void TearDown();
  42. AnalysisEnginePtr engine_;
  43. };
  44. void TestComposite::SetUp() {
  45. // init resource
  46. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  47. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  48. }
  49. void TestComposite::TearDown() {
  50. // destroy resource
  51. }
  52. class UTCompositeUtils {
  53. public:
  54. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int> shp) {
  55. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  56. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  57. }
  58. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  59. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  60. std::vector<AnfNodePtr> inputs;
  61. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  62. for (size_t i = 0; i < nparam; i++) {
  63. inputs.push_back(func_graph->add_parameter());
  64. }
  65. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  66. inputs.clear();
  67. inputs.push_back(NewValueNode(prim::kPrimReturn));
  68. inputs.push_back(cnode_prim);
  69. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  70. func_graph->set_return(cnode_return);
  71. return func_graph;
  72. }
  73. };
  74. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  75. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  76. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  77. AbstractBasePtrList eles;
  78. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  79. size_t tuple_size = 6;
  80. for (size_t i = 0; i < tuple_size; i++) {
  81. eles.push_back(tensor);
  82. }
  83. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  84. auto start_index = std::make_shared<AbstractScalar>(1);
  85. auto stop_index = std::make_shared<AbstractScalar>(5);
  86. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  87. try {
  88. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  89. FAIL() << "Excepted exception :Args type is wrong";
  90. } catch (std::runtime_error const &err) {
  91. ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
  92. } catch (...) {
  93. FAIL() << "Excepted exception :Args type is wrong";
  94. }
  95. }
  96. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  97. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  98. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  99. AbstractBasePtrList eles;
  100. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  101. size_t tuple_size = 6;
  102. for (size_t i = 0; i < tuple_size; i++) {
  103. eles.push_back(tensor);
  104. }
  105. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  106. auto start_index = std::make_shared<AbstractScalar>(1);
  107. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  108. try {
  109. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  110. FAIL() << "Excepted exception :Args type is wrong";
  111. } catch (std::runtime_error const &err) {
  112. ASSERT_TRUE(std::string(err.what()).find("TypeError") != std::string::npos);
  113. } catch (...) {
  114. FAIL() << "Excepted exception :Args type is wrong";
  115. }
  116. }
  117. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  118. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  119. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  120. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  121. AbstractBasePtrList eles;
  122. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  123. size_t tuple_size = 6;
  124. for (size_t i = 0; i < tuple_size; i++) {
  125. eles.push_back(tensor);
  126. }
  127. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  128. auto start_index = std::make_shared<AbstractScalar>(1);
  129. auto stop_index = std::make_shared<AbstractScalar>(6);
  130. auto step = std::make_shared<AbstractScalar>(2);
  131. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  132. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  133. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  134. if (ret == nullptr) {
  135. FAIL() << "Cast ret to abstract tuple failed.";
  136. }
  137. size_t real = ret->size();
  138. size_t expect = 3;
  139. ASSERT_EQ(real, expect);
  140. }
  141. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  142. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  143. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  144. AbstractBasePtrList eles;
  145. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  146. size_t tuple_size = 6;
  147. for (size_t i = 0; i < tuple_size; i++) {
  148. eles.push_back(tensor);
  149. }
  150. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  151. auto start_index = std::make_shared<AbstractScalar>(1);
  152. auto stop_index = std::make_shared<AbstractScalar>(5);
  153. auto step = std::make_shared<AbstractNone>();
  154. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  155. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  156. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  157. if (ret == nullptr) {
  158. FAIL() << "Cast ret to abstract tuple failed.";
  159. }
  160. size_t real = ret->size();
  161. size_t expect = 4;
  162. ASSERT_EQ(real, expect);
  163. }
  164. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  165. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  166. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  167. AbstractBasePtrList eles;
  168. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  169. size_t tuple_size = 6;
  170. for (size_t i = 0; i < tuple_size; i++) {
  171. eles.push_back(tensor);
  172. }
  173. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  174. auto start_index = std::make_shared<AbstractNone>();
  175. auto stop_index = std::make_shared<AbstractNone>();
  176. auto step = std::make_shared<AbstractScalar>(-1);
  177. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  178. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  179. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  180. if (ret == nullptr) {
  181. FAIL() << "Cast ret to abstract tuple failed.";
  182. }
  183. size_t real = ret->size();
  184. size_t expect = 6;
  185. ASSERT_EQ(real, expect);
  186. }
  187. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  188. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  189. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  190. AbstractBasePtrList eles;
  191. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  192. size_t tuple_size = 6;
  193. for (size_t i = 0; i < tuple_size; i++) {
  194. eles.push_back(tensor);
  195. }
  196. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  197. auto start_index = std::make_shared<AbstractScalar>(-2);
  198. auto stop_index = std::make_shared<AbstractNone>();
  199. auto step = std::make_shared<AbstractScalar>(-1);
  200. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  201. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  202. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  203. if (ret == nullptr) {
  204. FAIL() << "Cast ret to abstract tuple failed.";
  205. }
  206. size_t real = ret->size();
  207. size_t expect = 5;
  208. ASSERT_EQ(real, expect);
  209. }
  210. TEST_F(TestComposite, test_TensorSliceBySlice) {
  211. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  212. FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  213. AbstractBasePtrList eles;
  214. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  215. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  216. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  217. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  218. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  219. AbstractBasePtrList args_spec_list = {tensor, slice};
  220. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred);
  221. if (ret == nullptr) {
  222. FAIL() << "Cast ret to abstract array failed.";
  223. }
  224. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8});
  225. ASSERT_EQ(*ret, *expect);
  226. }
  227. TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
  228. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  229. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  230. AbstractBasePtrList eles;
  231. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(0);
  232. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  233. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  234. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  235. eles.push_back(slice);
  236. start_index = std::make_shared<AbstractScalar>(1);
  237. stop_index = std::make_shared<AbstractScalar>(5);
  238. step = std::make_shared<AbstractScalar>(1);
  239. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  240. eles.push_back(slice);
  241. start_index = std::make_shared<AbstractScalar>(2);
  242. stop_index = std::make_shared<AbstractScalar>(8);
  243. step = std::make_shared<AbstractScalar>(3);
  244. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  245. eles.push_back(slice);
  246. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  247. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  248. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  249. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  250. if (ret == nullptr) {
  251. FAIL() << "Cast ret to abstract array failed.";
  252. }
  253. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2});
  254. ASSERT_EQ(*ret, *expect);
  255. }
  256. TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
  257. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  258. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  259. AbstractBasePtrList eles;
  260. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  261. AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(5);
  262. AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  263. AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  264. eles.push_back(slice);
  265. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  266. eles.push_back(elem_index);
  267. start_index = std::make_shared<AbstractScalar>(2);
  268. stop_index = std::make_shared<AbstractScalar>(6);
  269. step = std::make_shared<AbstractScalar>(1);
  270. slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  271. eles.push_back(slice);
  272. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  273. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  274. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  275. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  276. if (ret == nullptr) {
  277. FAIL() << "Cast ret to abstract array failed.";
  278. }
  279. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4});
  280. ASSERT_EQ(*ret, *expect);
  281. }
  282. TEST_F(TestComposite, test_TensorSliceByScalar) {
  283. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  284. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  285. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  286. AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
  287. AbstractBasePtrList args_spec_list = {tensor, start_index};
  288. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  289. if (ret == nullptr) {
  290. FAIL() << "Cast ret to abstract array failed.";
  291. }
  292. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8});
  293. ASSERT_EQ(*ret, *expect);
  294. }
  295. TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
  296. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  297. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  298. AbstractBasePtrList eles;
  299. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  300. eles.push_back(elem_index);
  301. elem_index = std::make_shared<AbstractScalar>(3);
  302. eles.push_back(elem_index);
  303. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  304. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  305. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  306. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  307. if (ret == nullptr) {
  308. FAIL() << "Cast ret to abstract array failed.";
  309. }
  310. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8});
  311. ASSERT_EQ(*ret, *expect);
  312. }
  313. TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
  314. MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  315. FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
  316. AbstractBasePtrList eles;
  317. AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(3);
  318. eles.push_back(elem_index);
  319. elem_index = std::make_shared<AbstractScalar>(0);
  320. eles.push_back(elem_index);
  321. elem_index = std::make_shared<AbstractScalar>(6);
  322. eles.push_back(elem_index);
  323. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  324. AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  325. AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
  326. AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  327. if (ret == nullptr) {
  328. FAIL() << "Cast ret to abstract array failed.";
  329. }
  330. AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({});
  331. ASSERT_EQ(*ret, *expect);
  332. }
  333. TEST_F(TestComposite, test_UnpackCall_3args) {
  334. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  335. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  336. auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  337. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  338. AbstractBasePtrList eles;
  339. for (size_t i = 0; i < 6; i++) {
  340. eles.push_back(tensor);
  341. }
  342. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  343. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  344. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  345. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  346. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  347. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  348. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  349. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  350. if (ret == nullptr) {
  351. FAIL() << "Cast ret to abstract tuple failed.";
  352. }
  353. size_t real = ret->size();
  354. size_t expect = 9;
  355. ASSERT_EQ(real, expect);
  356. }
  357. TEST_F(TestComposite, test_UnpackCall_5args) {
  358. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  359. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  360. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  361. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  362. AbstractBasePtrList eles;
  363. for (size_t i = 0; i < 6; i++) {
  364. eles.push_back(tensor);
  365. }
  366. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  367. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  368. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  369. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  370. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  371. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  372. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  373. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  374. if (ret == nullptr) {
  375. FAIL() << "Cast ret to abstract tuple failed.";
  376. }
  377. size_t real = ret->size();
  378. size_t expect = 18;
  379. ASSERT_EQ(real, expect);
  380. }
  381. TEST_F(TestComposite, test_ZipOperation) {
  382. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  383. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  384. AbstractBasePtrList eles;
  385. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  386. size_t tuple_size = 3;
  387. for (size_t i = 0; i < tuple_size; i++) {
  388. eles.push_back(tensor);
  389. }
  390. auto tuple = std::make_shared<AbstractTuple>(eles);
  391. AbstractBasePtrList args_spec_list = {tuple};
  392. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred);
  393. if (ret == nullptr) {
  394. FAIL() << "Cast ret to abstract tuple failed.";
  395. }
  396. size_t real = ret->size();
  397. size_t expect = 3;
  398. ASSERT_EQ(real, expect);
  399. }
  400. } // namespace mindspore