You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

composite_test.cc 15 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "frontend/operator/composite/composite.h"
  21. #include "frontend/operator/ops.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "abstract/abstract_function.h"
  24. #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractTensor = abstract::AbstractTensor;
  34. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  35. using AbstractNone = abstract::AbstractNone;
  36. using AbstractAttribute = abstract::AbstractAttribute;
  37. using AnalysisEngine = abstract::AnalysisEngine;
  38. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  39. class TestComposite : public UT::Common {
  40. public:
  41. virtual void SetUp();
  42. virtual void TearDown();
  43. AnalysisEnginePtr engine_;
  44. };
  45. void TestComposite::SetUp() {
  46. // init resource
  47. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  48. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  49. }
  50. void TestComposite::TearDown() {
  51. // destroy resource
  52. }
  53. class UTCompositeUtils {
  54. public:
  55. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  56. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  57. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  58. }
  59. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  60. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  63. for (size_t i = 0; i < nparam; i++) {
  64. inputs.push_back(func_graph->add_parameter());
  65. }
  66. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(prim::kPrimReturn));
  69. inputs.push_back(cnode_prim);
  70. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  71. func_graph->set_return(cnode_return);
  72. return func_graph;
  73. }
  74. };
  75. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  76. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  77. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  78. AbstractBasePtrList eles;
  79. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  80. size_t tuple_size = 6;
  81. for (size_t i = 0; i < tuple_size; i++) {
  82. eles.push_back(tensor);
  83. }
  84. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  85. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  86. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  87. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  88. try {
  89. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  90. FAIL() << "Excepted exception :Args type is wrong";
  91. } catch (std::runtime_error const &err) {
  92. ASSERT_TRUE(std::string(err.what()).find("For 'TupleSlice', the number of input should be 2, but got 3") !=
  93. std::string::npos);
  94. } catch (...) {
  95. FAIL() << "Excepted exception :Args type is wrong";
  96. }
  97. }
  98. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  99. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  100. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  101. AbstractBasePtrList eles;
  102. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  103. size_t tuple_size = 6;
  104. for (size_t i = 0; i < tuple_size; i++) {
  105. eles.push_back(tensor);
  106. }
  107. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  108. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  109. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  110. try {
  111. trace::ClearTraceStack();
  112. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  113. FAIL() << "Excepted exception: Args type is wrong";
  114. } catch (pybind11::type_error const &err) {
  115. ASSERT_TRUE(true);
  116. } catch (std::runtime_error const &err) {
  117. if (std::strstr(err.what(), "TypeError") != nullptr) {
  118. ASSERT_TRUE(true);
  119. } else {
  120. FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
  121. }
  122. } catch (...) {
  123. FAIL() << "Excepted exception: Args type is wrong";
  124. }
  125. }
  126. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  127. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  128. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  129. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  130. AbstractBasePtrList eles;
  131. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  132. size_t tuple_size = 6;
  133. for (size_t i = 0; i < tuple_size; i++) {
  134. eles.push_back(tensor);
  135. }
  136. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  137. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  138. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  139. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  140. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  141. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  142. AbstractTuplePtr ret =
  143. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  144. if (ret == nullptr) {
  145. FAIL() << "Cast ret to abstract tuple failed.";
  146. }
  147. size_t real = ret->size();
  148. size_t expect = 3;
  149. ASSERT_EQ(real, expect);
  150. }
  151. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  152. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  153. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  154. AbstractBasePtrList eles;
  155. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  156. size_t tuple_size = 6;
  157. for (size_t i = 0; i < tuple_size; i++) {
  158. eles.push_back(tensor);
  159. }
  160. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  161. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  162. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  163. auto step = std::make_shared<AbstractNone>();
  164. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  165. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  166. AbstractTuplePtr ret =
  167. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  168. if (ret == nullptr) {
  169. FAIL() << "Cast ret to abstract tuple failed.";
  170. }
  171. size_t real = ret->size();
  172. size_t expect = 4;
  173. ASSERT_EQ(real, expect);
  174. }
  175. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  176. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  177. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  178. AbstractBasePtrList eles;
  179. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  180. size_t tuple_size = 6;
  181. for (size_t i = 0; i < tuple_size; i++) {
  182. eles.push_back(tensor);
  183. }
  184. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  185. auto start_index = std::make_shared<AbstractNone>();
  186. auto stop_index = std::make_shared<AbstractNone>();
  187. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  188. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  189. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  190. AbstractTuplePtr ret =
  191. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  192. if (ret == nullptr) {
  193. FAIL() << "Cast ret to abstract tuple failed.";
  194. }
  195. size_t real = ret->size();
  196. size_t expect = 6;
  197. ASSERT_EQ(real, expect);
  198. }
  199. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  200. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  201. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  202. AbstractBasePtrList eles;
  203. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  204. size_t tuple_size = 6;
  205. for (size_t i = 0; i < tuple_size; i++) {
  206. eles.push_back(tensor);
  207. }
  208. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  209. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  210. auto stop_index = std::make_shared<AbstractNone>();
  211. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  212. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  213. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  214. AbstractTuplePtr ret =
  215. dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  216. if (ret == nullptr) {
  217. FAIL() << "Cast ret to abstract tuple failed.";
  218. }
  219. size_t real = ret->size();
  220. size_t expect = 5;
  221. ASSERT_EQ(real, expect);
  222. }
  223. TEST_F(TestComposite, test_UnpackCall_3args) {
  224. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  225. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  226. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  227. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  228. AbstractBasePtrList eles;
  229. for (size_t i = 0; i < 6; i++) {
  230. eles.push_back(tensor);
  231. }
  232. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  233. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  234. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  235. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  236. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  237. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  238. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  239. AbstractTuplePtr ret =
  240. dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
  241. if (ret == nullptr) {
  242. FAIL() << "Cast ret to abstract tuple failed.";
  243. }
  244. size_t real = ret->size();
  245. size_t expect = 9;
  246. ASSERT_EQ(real, expect);
  247. }
  248. TEST_F(TestComposite, test_UnpackCall_5args) {
  249. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  250. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  251. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  252. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  253. AbstractBasePtrList eles;
  254. for (size_t i = 0; i < 6; i++) {
  255. eles.push_back(tensor);
  256. }
  257. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  258. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  259. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  260. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  261. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  262. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  263. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  264. AbstractTuplePtr ret =
  265. dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
  266. if (ret == nullptr) {
  267. FAIL() << "Cast ret to abstract tuple failed.";
  268. }
  269. size_t real = ret->size();
  270. size_t expect = 18;
  271. ASSERT_EQ(real, expect);
  272. }
  273. TEST_F(TestComposite, test_ZipOperation) {
  274. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  275. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  276. AbstractBasePtrList eles;
  277. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  278. size_t tuple_size = 3;
  279. for (size_t i = 0; i < tuple_size; i++) {
  280. eles.push_back(tensor);
  281. }
  282. auto tuple = std::make_shared<AbstractTuple>(eles);
  283. AbstractBasePtrList args_spec_list = {tuple};
  284. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract());
  285. if (ret == nullptr) {
  286. FAIL() << "Cast ret to abstract tuple failed.";
  287. }
  288. size_t real = ret->size();
  289. size_t expect = 3;
  290. ASSERT_EQ(real, expect);
  291. }
  292. /// Feature: Shard operation.
  293. /// Description: Test the func_graph generation of Shard op and the inference of the Shard caller.
  294. /// Expectation: Generate and the infer successfully.
  295. TEST_F(TestComposite, test_shard) {
  296. // Make origin func_graph which includes a relu node.
  297. FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
  298. std::vector<AnfNodePtr> inputs;
  299. inputs.push_back(NewValueNode(prim::kPrimRelu));
  300. inputs.push_back(origin_func_graph->add_parameter());
  301. CNodePtr relu = origin_func_graph->NewCNode(inputs);
  302. inputs.clear();
  303. inputs.push_back(NewValueNode(prim::kPrimReturn));
  304. inputs.push_back(relu);
  305. CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
  306. origin_func_graph->set_return(origin_return);
  307. // Make the func_graph which includes a Shard meta_func_graph.
  308. FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
  309. MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
  310. inputs.clear();
  311. inputs.push_back(NewValueNode(shard_op));
  312. inputs.push_back(NewValueNode(origin_func_graph));
  313. for (size_t i = 0; i < 4; ++i) {
  314. inputs.push_back(NewValueNode(MakeValue(0)));
  315. }
  316. CNodePtr shard = shard_func_graph->NewCNode(inputs);
  317. inputs.clear();
  318. inputs.push_back(shard);
  319. inputs.push_back(shard_func_graph->add_parameter());
  320. CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
  321. inputs.clear();
  322. inputs.push_back(NewValueNode(prim::kPrimReturn));
  323. inputs.push_back(shard_user);
  324. CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
  325. shard_func_graph->set_return(shard_return);
  326. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  327. AbstractBasePtrList args_spec_list = {tensor};
  328. auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract();
  329. ASSERT_NE(ret, nullptr);
  330. ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
  331. auto build_shape = ret->BuildShape();
  332. EXPECT_TRUE(build_shape->isa<abstract::Shape>());
  333. auto shape = build_shape->cast<abstract::ShapePtr>();
  334. ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
  335. }
  336. } // namespace mindspore