You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite_test.cc 13 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "frontend/operator/composite/composite.h"
  21. #include "frontend/operator/ops.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "abstract/abstract_function.h"
  24. #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractTensor = abstract::AbstractTensor;
  34. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  35. using AbstractNone = abstract::AbstractNone;
  36. using AbstractAttribute = abstract::AbstractAttribute;
  37. using AnalysisEngine = abstract::AnalysisEngine;
  38. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  39. class TestComposite : public UT::Common {
  40. public:
  41. virtual void SetUp();
  42. virtual void TearDown();
  43. AnalysisEnginePtr engine_;
  44. };
  45. void TestComposite::SetUp() {
  46. // init resource
  47. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  48. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  49. }
  50. void TestComposite::TearDown() {
  51. // destroy resource
  52. }
  53. class UTCompositeUtils {
  54. public:
  55. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  56. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  57. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  58. }
  59. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  60. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  63. for (size_t i = 0; i < nparam; i++) {
  64. inputs.push_back(func_graph->add_parameter());
  65. }
  66. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(prim::kPrimReturn));
  69. inputs.push_back(cnode_prim);
  70. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  71. func_graph->set_return(cnode_return);
  72. return func_graph;
  73. }
  74. };
  75. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  76. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  77. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  78. AbstractBasePtrList eles;
  79. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  80. size_t tuple_size = 6;
  81. for (size_t i = 0; i < tuple_size; i++) {
  82. eles.push_back(tensor);
  83. }
  84. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  85. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  86. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  87. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  88. try {
  89. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  90. FAIL() << "Excepted exception :Args type is wrong";
  91. } catch (std::runtime_error const &err) {
  92. ASSERT_TRUE(std::string(err.what()).find("For 'TupleSlice', the number of input should be 2, but got 3") !=
  93. std::string::npos);
  94. } catch (...) {
  95. FAIL() << "Excepted exception :Args type is wrong";
  96. }
  97. }
  98. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  99. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  100. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  101. AbstractBasePtrList eles;
  102. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  103. size_t tuple_size = 6;
  104. for (size_t i = 0; i < tuple_size; i++) {
  105. eles.push_back(tensor);
  106. }
  107. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  108. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  109. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  110. try {
  111. trace::ClearTraceStack();
  112. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  113. FAIL() << "Excepted exception: Args type is wrong";
  114. } catch (pybind11::type_error const &err) {
  115. ASSERT_TRUE(true);
  116. } catch (std::runtime_error const &err) {
  117. if (std::strstr(err.what(), "TypeError") != nullptr) {
  118. ASSERT_TRUE(true);
  119. } else {
  120. FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
  121. }
  122. } catch (...) {
  123. FAIL() << "Excepted exception: Args type is wrong";
  124. }
  125. }
  126. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  127. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  128. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  129. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  130. AbstractBasePtrList eles;
  131. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  132. size_t tuple_size = 6;
  133. for (size_t i = 0; i < tuple_size; i++) {
  134. eles.push_back(tensor);
  135. }
  136. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  137. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  138. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  139. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  140. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  141. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  142. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  143. if (ret == nullptr) {
  144. FAIL() << "Cast ret to abstract tuple failed.";
  145. }
  146. size_t real = ret->size();
  147. size_t expect = 3;
  148. ASSERT_EQ(real, expect);
  149. }
  150. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  151. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  152. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  153. AbstractBasePtrList eles;
  154. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  155. size_t tuple_size = 6;
  156. for (size_t i = 0; i < tuple_size; i++) {
  157. eles.push_back(tensor);
  158. }
  159. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  160. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  161. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  162. auto step = std::make_shared<AbstractNone>();
  163. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  164. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  165. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  166. if (ret == nullptr) {
  167. FAIL() << "Cast ret to abstract tuple failed.";
  168. }
  169. size_t real = ret->size();
  170. size_t expect = 4;
  171. ASSERT_EQ(real, expect);
  172. }
  173. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  174. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  175. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  176. AbstractBasePtrList eles;
  177. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  178. size_t tuple_size = 6;
  179. for (size_t i = 0; i < tuple_size; i++) {
  180. eles.push_back(tensor);
  181. }
  182. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  183. auto start_index = std::make_shared<AbstractNone>();
  184. auto stop_index = std::make_shared<AbstractNone>();
  185. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  186. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  187. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  188. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  189. if (ret == nullptr) {
  190. FAIL() << "Cast ret to abstract tuple failed.";
  191. }
  192. size_t real = ret->size();
  193. size_t expect = 6;
  194. ASSERT_EQ(real, expect);
  195. }
  196. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  197. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  198. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  199. AbstractBasePtrList eles;
  200. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  201. size_t tuple_size = 6;
  202. for (size_t i = 0; i < tuple_size; i++) {
  203. eles.push_back(tensor);
  204. }
  205. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  206. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  207. auto stop_index = std::make_shared<AbstractNone>();
  208. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  209. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  210. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  211. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  212. if (ret == nullptr) {
  213. FAIL() << "Cast ret to abstract tuple failed.";
  214. }
  215. size_t real = ret->size();
  216. size_t expect = 5;
  217. ASSERT_EQ(real, expect);
  218. }
  219. TEST_F(TestComposite, test_UnpackCall_3args) {
  220. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  221. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  222. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  223. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  224. AbstractBasePtrList eles;
  225. for (size_t i = 0; i < 6; i++) {
  226. eles.push_back(tensor);
  227. }
  228. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  229. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  230. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  231. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  232. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  233. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  234. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  235. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  236. if (ret == nullptr) {
  237. FAIL() << "Cast ret to abstract tuple failed.";
  238. }
  239. size_t real = ret->size();
  240. size_t expect = 9;
  241. ASSERT_EQ(real, expect);
  242. }
  243. TEST_F(TestComposite, test_UnpackCall_5args) {
  244. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  245. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  246. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  247. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  248. AbstractBasePtrList eles;
  249. for (size_t i = 0; i < 6; i++) {
  250. eles.push_back(tensor);
  251. }
  252. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  253. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  254. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  255. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  256. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  257. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  258. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  259. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  260. if (ret == nullptr) {
  261. FAIL() << "Cast ret to abstract tuple failed.";
  262. }
  263. size_t real = ret->size();
  264. size_t expect = 18;
  265. ASSERT_EQ(real, expect);
  266. }
  267. TEST_F(TestComposite, test_ZipOperation) {
  268. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  269. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  270. AbstractBasePtrList eles;
  271. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  272. size_t tuple_size = 3;
  273. for (size_t i = 0; i < tuple_size; i++) {
  274. eles.push_back(tensor);
  275. }
  276. auto tuple = std::make_shared<AbstractTuple>(eles);
  277. AbstractBasePtrList args_spec_list = {tuple};
  278. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
  279. if (ret == nullptr) {
  280. FAIL() << "Cast ret to abstract tuple failed.";
  281. }
  282. size_t real = ret->size();
  283. size_t expect = 3;
  284. ASSERT_EQ(real, expect);
  285. }
  286. } // namespace mindspore