You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite_test.cc 22 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "frontend/operator/composite/composite.h"
  21. #include "frontend/operator/ops.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "abstract/abstract_function.h"
  24. #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractList = abstract::AbstractList;
  34. using AbstractListPtr = abstract::AbstractListPtr;
  35. using AbstractTensor = abstract::AbstractTensor;
  36. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  37. using AbstractNone = abstract::AbstractNone;
  38. using AbstractAttribute = abstract::AbstractAttribute;
  39. using AnalysisEngine = abstract::AnalysisEngine;
  40. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  41. class TestComposite : public UT::Common {
  42. public:
  43. virtual void SetUp();
  44. virtual void TearDown();
  45. AnalysisEnginePtr engine_;
  46. };
  47. void TestComposite::SetUp() {
  48. // init resource
  49. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  50. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  51. }
  52. void TestComposite::TearDown() {
  53. // destroy resource
  54. }
  55. class UTCompositeUtils {
  56. public:
  57. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  58. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  59. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  60. }
  61. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  62. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  63. std::vector<AnfNodePtr> inputs;
  64. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  65. for (size_t i = 0; i < nparam; i++) {
  66. inputs.push_back(func_graph->add_parameter());
  67. }
  68. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  69. inputs.clear();
  70. inputs.push_back(NewValueNode(prim::kPrimReturn));
  71. inputs.push_back(cnode_prim);
  72. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  73. func_graph->set_return(cnode_return);
  74. return func_graph;
  75. }
  76. };
  77. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  78. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  79. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  80. AbstractBasePtrList eles;
  81. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  82. size_t tuple_size = 6;
  83. for (size_t i = 0; i < tuple_size; i++) {
  84. eles.push_back(tensor);
  85. }
  86. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  87. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  88. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  89. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  90. try {
  91. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  92. FAIL() << "Excepted exception :Args type is wrong";
  93. } catch (std::runtime_error const &err) {
  94. ASSERT_TRUE(std::string(err.what()).find("For 'TupleSlice', the number of input should be 2, but got 3") !=
  95. std::string::npos);
  96. } catch (...) {
  97. FAIL() << "Excepted exception :Args type is wrong";
  98. }
  99. }
  100. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  101. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  102. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  103. AbstractBasePtrList eles;
  104. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  105. size_t tuple_size = 6;
  106. for (size_t i = 0; i < tuple_size; i++) {
  107. eles.push_back(tensor);
  108. }
  109. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  110. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  111. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  112. try {
  113. trace::ClearTraceStack();
  114. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  115. FAIL() << "Excepted exception: Args type is wrong";
  116. } catch (pybind11::type_error const &err) {
  117. ASSERT_TRUE(true);
  118. } catch (std::runtime_error const &err) {
  119. if (std::strstr(err.what(), "TypeError") != nullptr) {
  120. ASSERT_TRUE(true);
  121. } else {
  122. FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
  123. }
  124. } catch (...) {
  125. FAIL() << "Excepted exception: Args type is wrong";
  126. }
  127. }
  128. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  129. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  130. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  131. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  132. AbstractBasePtrList eles;
  133. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  134. size_t tuple_size = 6;
  135. for (size_t i = 0; i < tuple_size; i++) {
  136. eles.push_back(tensor);
  137. }
  138. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  139. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  140. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  141. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  142. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  143. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  144. AbstractTuplePtr ret =
  145. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  146. if (ret == nullptr) {
  147. FAIL() << "Cast ret to abstract tuple failed.";
  148. }
  149. size_t real = ret->size();
  150. size_t expect = 3;
  151. ASSERT_EQ(real, expect);
  152. }
  153. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  154. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  155. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  156. AbstractBasePtrList eles;
  157. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  158. size_t tuple_size = 6;
  159. for (size_t i = 0; i < tuple_size; i++) {
  160. eles.push_back(tensor);
  161. }
  162. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  163. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  164. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  165. auto step = std::make_shared<AbstractNone>();
  166. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  167. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  168. AbstractTuplePtr ret =
  169. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  170. if (ret == nullptr) {
  171. FAIL() << "Cast ret to abstract tuple failed.";
  172. }
  173. size_t real = ret->size();
  174. size_t expect = 4;
  175. ASSERT_EQ(real, expect);
  176. }
  177. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  178. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  179. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  180. AbstractBasePtrList eles;
  181. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  182. size_t tuple_size = 6;
  183. for (size_t i = 0; i < tuple_size; i++) {
  184. eles.push_back(tensor);
  185. }
  186. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  187. auto start_index = std::make_shared<AbstractNone>();
  188. auto stop_index = std::make_shared<AbstractNone>();
  189. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  190. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  191. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  192. AbstractTuplePtr ret =
  193. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  194. if (ret == nullptr) {
  195. FAIL() << "Cast ret to abstract tuple failed.";
  196. }
  197. size_t real = ret->size();
  198. size_t expect = 6;
  199. ASSERT_EQ(real, expect);
  200. }
  201. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  202. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  203. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  204. AbstractBasePtrList eles;
  205. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  206. size_t tuple_size = 6;
  207. for (size_t i = 0; i < tuple_size; i++) {
  208. eles.push_back(tensor);
  209. }
  210. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  211. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  212. auto stop_index = std::make_shared<AbstractNone>();
  213. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  214. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  215. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  216. AbstractTuplePtr ret =
  217. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  218. if (ret == nullptr) {
  219. FAIL() << "Cast ret to abstract tuple failed.";
  220. }
  221. size_t real = ret->size();
  222. size_t expect = 5;
  223. ASSERT_EQ(real, expect);
  224. }
  225. /// Feature: Test list slice
  226. /// Description: The second input is a scalar
  227. /// Expectation: Throw type error
  228. TEST_F(TestComposite, test_ListSlice_arg_one_number) {
  229. MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
  230. FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
  231. AbstractBasePtrList eles;
  232. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  233. size_t list_size = 6;
  234. for (size_t i = 0; i < list_size; i++) {
  235. eles.push_back(tensor);
  236. }
  237. auto list_tensor = std::make_shared<AbstractList>(eles);
  238. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  239. AbstractBasePtrList args_spec_list = {list_tensor, start_index};
  240. try {
  241. trace::ClearTraceStack();
  242. engine_->Run(list_graph, args_spec_list);
  243. FAIL() << "Excepted exception: Args type is wrong";
  244. } catch (pybind11::type_error const &err) {
  245. ASSERT_TRUE(true);
  246. } catch (std::runtime_error const &err) {
  247. if (std::strstr(err.what(), "TypeError") != nullptr) {
  248. ASSERT_TRUE(true);
  249. } else {
  250. FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
  251. }
  252. } catch (...) {
  253. FAIL() << "Excepted exception: Args type is wrong";
  254. }
  255. }
  256. /// Feature: Test list slice
  257. /// Description: Test List slice
  258. /// Expectation: No Expectation
  259. TEST_F(TestComposite, test_ListSlice_arg_slice) {
  260. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  261. MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
  262. FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
  263. AbstractBasePtrList eles;
  264. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  265. size_t list_size = 6;
  266. for (size_t i = 0; i < list_size; i++) {
  267. eles.push_back(tensor);
  268. }
  269. auto list_tensor = std::make_shared<AbstractList>(eles);
  270. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  271. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  272. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  273. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  274. AbstractBasePtrList args_spec_list = {list_tensor, slice};
  275. AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
  276. if (ret == nullptr) {
  277. FAIL() << "Cast ret to abstract list failed.";
  278. }
  279. size_t real = ret->size();
  280. size_t expect = 3;
  281. ASSERT_EQ(real, expect);
  282. }
  283. /// Feature: Test list slice
  284. /// Description: Test List slice the step is none
  285. /// Expectation: No Expectation
  286. TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
  287. MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
  288. FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
  289. AbstractBasePtrList eles;
  290. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  291. size_t list_size = 6;
  292. for (size_t i = 0; i < list_size; i++) {
  293. eles.push_back(tensor);
  294. }
  295. auto list_tensor = std::make_shared<AbstractList>(eles);
  296. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  297. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  298. auto step = std::make_shared<AbstractNone>();
  299. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  300. AbstractBasePtrList args_spec_list = {list_tensor, slice};
  301. AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
  302. if (ret == nullptr) {
  303. FAIL() << "Cast ret to abstract list failed.";
  304. }
  305. size_t real = ret->size();
  306. size_t expect = 4;
  307. ASSERT_EQ(real, expect);
  308. }
  309. /// Feature: Test list slice
  310. /// Description: Test List slice the step is negative
  311. /// Expectation: No Expectation
  312. TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
  313. MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
  314. FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
  315. AbstractBasePtrList eles;
  316. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  317. size_t list_size = 6;
  318. for (size_t i = 0; i < list_size; i++) {
  319. eles.push_back(tensor);
  320. }
  321. auto list_tensor = std::make_shared<AbstractList>(eles);
  322. auto start_index = std::make_shared<AbstractNone>();
  323. auto stop_index = std::make_shared<AbstractNone>();
  324. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  325. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  326. AbstractBasePtrList args_spec_list = {list_tensor, slice};
  327. AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
  328. if (ret == nullptr) {
  329. FAIL() << "Cast ret to abstract list failed.";
  330. }
  331. size_t real = ret->size();
  332. size_t expect = 6;
  333. ASSERT_EQ(real, expect);
  334. }
  335. /// Feature: Test list slice
  336. /// Description: Test List slice the step is positive
  337. /// Expectation: No Expectation
  338. TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
  339. MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
  340. FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
  341. AbstractBasePtrList eles;
  342. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  343. size_t list_size = 6;
  344. for (size_t i = 0; i < list_size; i++) {
  345. eles.push_back(tensor);
  346. }
  347. auto list_tensor = std::make_shared<AbstractList>(eles);
  348. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  349. auto stop_index = std::make_shared<AbstractNone>();
  350. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  351. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  352. AbstractBasePtrList args_spec_list = {list_tensor, slice};
  353. AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
  354. if (ret == nullptr) {
  355. FAIL() << "Cast ret to abstract list failed.";
  356. }
  357. size_t real = ret->size();
  358. size_t expect = 5;
  359. ASSERT_EQ(real, expect);
  360. }
  361. TEST_F(TestComposite, test_UnpackCall_3args) {
  362. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  363. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  364. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  365. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  366. AbstractBasePtrList eles;
  367. for (size_t i = 0; i < 6; i++) {
  368. eles.push_back(tensor);
  369. }
  370. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  371. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  372. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  373. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  374. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  375. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  376. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  377. AbstractTuplePtr ret =
  378. dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
  379. if (ret == nullptr) {
  380. FAIL() << "Cast ret to abstract tuple failed.";
  381. }
  382. size_t real = ret->size();
  383. size_t expect = 9;
  384. ASSERT_EQ(real, expect);
  385. }
  386. TEST_F(TestComposite, test_UnpackCall_5args) {
  387. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  388. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  389. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  390. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  391. AbstractBasePtrList eles;
  392. for (size_t i = 0; i < 6; i++) {
  393. eles.push_back(tensor);
  394. }
  395. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  396. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  397. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  398. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  399. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  400. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  401. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  402. AbstractTuplePtr ret =
  403. dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
  404. if (ret == nullptr) {
  405. FAIL() << "Cast ret to abstract tuple failed.";
  406. }
  407. size_t real = ret->size();
  408. size_t expect = 18;
  409. ASSERT_EQ(real, expect);
  410. }
  411. TEST_F(TestComposite, test_ZipOperation) {
  412. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  413. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  414. AbstractBasePtrList eles;
  415. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  416. size_t tuple_size = 3;
  417. for (size_t i = 0; i < tuple_size; i++) {
  418. eles.push_back(tensor);
  419. }
  420. auto tuple = std::make_shared<AbstractTuple>(eles);
  421. AbstractBasePtrList args_spec_list = {tuple};
  422. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract());
  423. if (ret == nullptr) {
  424. FAIL() << "Cast ret to abstract tuple failed.";
  425. }
  426. size_t real = ret->size();
  427. size_t expect = 3;
  428. ASSERT_EQ(real, expect);
  429. }
  430. /// Feature: Shard operation.
  431. /// Description: Test the func_graph generation of Shard op and the inference of the Shard caller.
  432. /// Expectation: Generate and the infer successfully.
  433. TEST_F(TestComposite, test_shard) {
  434. // Make origin func_graph which includes a relu node.
  435. FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
  436. std::vector<AnfNodePtr> inputs;
  437. inputs.push_back(NewValueNode(prim::kPrimRelu));
  438. inputs.push_back(origin_func_graph->add_parameter());
  439. CNodePtr relu = origin_func_graph->NewCNode(inputs);
  440. inputs.clear();
  441. inputs.push_back(NewValueNode(prim::kPrimReturn));
  442. inputs.push_back(relu);
  443. CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
  444. origin_func_graph->set_return(origin_return);
  445. // Make the func_graph which includes a Shard meta_func_graph.
  446. FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
  447. MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
  448. inputs.clear();
  449. inputs.push_back(NewValueNode(shard_op));
  450. inputs.push_back(NewValueNode(origin_func_graph));
  451. for (size_t i = 0; i < 4; ++i) {
  452. inputs.push_back(NewValueNode(MakeValue(0)));
  453. }
  454. CNodePtr shard = shard_func_graph->NewCNode(inputs);
  455. inputs.clear();
  456. inputs.push_back(shard);
  457. inputs.push_back(shard_func_graph->add_parameter());
  458. CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
  459. inputs.clear();
  460. inputs.push_back(NewValueNode(prim::kPrimReturn));
  461. inputs.push_back(shard_user);
  462. CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
  463. shard_func_graph->set_return(shard_return);
  464. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  465. AbstractBasePtrList args_spec_list = {tensor};
  466. auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract();
  467. ASSERT_NE(ret, nullptr);
  468. ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
  469. auto build_shape = ret->BuildShape();
  470. EXPECT_TRUE(build_shape->isa<abstract::Shape>());
  471. auto shape = build_shape->cast<abstract::ShapePtr>();
  472. ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
  473. }
  474. } // namespace mindspore