You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

composite_test.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "ir/anf.h"
  19. #include "ir/value.h"
  20. #include "operator/composite/composite.h"
  21. #include "operator/ops.h"
  22. #include "pipeline/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/abstract_function.h"
  24. #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractTensor = abstract::AbstractTensor;
  34. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  35. using AbstractNone = abstract::AbstractNone;
  36. using AbstractAttribute = abstract::AbstractAttribute;
  37. using AnalysisEngine = abstract::AnalysisEngine;
  38. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  39. class TestComposite : public UT::Common {
  40. public:
  41. virtual void SetUp();
  42. virtual void TearDown();
  43. AnalysisEnginePtr engine_;
  44. };
  45. void TestComposite::SetUp() {
  46. // init resource
  47. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  48. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  49. }
  50. void TestComposite::TearDown() {
  51. // destroy resource
  52. }
  53. class UTCompositeUtils {
  54. public:
  55. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int> shp) {
  56. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  57. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  58. }
  59. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  60. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  63. for (size_t i = 0; i < nparam; i++) {
  64. inputs.push_back(func_graph->add_parameter());
  65. }
  66. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(prim::kPrimReturn));
  69. inputs.push_back(cnode_prim);
  70. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  71. func_graph->set_return(cnode_return);
  72. return func_graph;
  73. }
  74. };
  75. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  76. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  77. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  78. AbstractBasePtrList eles;
  79. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  80. size_t tuple_size = 6;
  81. for (size_t i = 0; i < tuple_size; i++) {
  82. eles.push_back(tensor);
  83. }
  84. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  85. auto start_index = std::make_shared<AbstractScalar>(1);
  86. auto stop_index = std::make_shared<AbstractScalar>(5);
  87. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  88. try {
  89. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  90. FAIL() << "Excepted exception :Args type is wrong";
  91. } catch (std::runtime_error const &err) {
  92. ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
  93. } catch (...) {
  94. FAIL() << "Excepted exception :Args type is wrong";
  95. }
  96. }
  97. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  98. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  99. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  100. AbstractBasePtrList eles;
  101. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  102. size_t tuple_size = 6;
  103. for (size_t i = 0; i < tuple_size; i++) {
  104. eles.push_back(tensor);
  105. }
  106. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  107. auto start_index = std::make_shared<AbstractScalar>(1);
  108. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  109. try {
  110. trace::ClearTraceStack();
  111. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  112. FAIL() << "Excepted exception :Args type is wrong";
  113. } catch (pybind11::type_error const &err) {
  114. ASSERT_TRUE(true);
  115. } catch (...) {
  116. FAIL() << "Excepted exception :Args type is wrong";
  117. }
  118. }
  119. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  120. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  121. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  122. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  123. AbstractBasePtrList eles;
  124. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  125. size_t tuple_size = 6;
  126. for (size_t i = 0; i < tuple_size; i++) {
  127. eles.push_back(tensor);
  128. }
  129. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  130. auto start_index = std::make_shared<AbstractScalar>(1);
  131. auto stop_index = std::make_shared<AbstractScalar>(6);
  132. auto step = std::make_shared<AbstractScalar>(2);
  133. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  134. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  135. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  136. if (ret == nullptr) {
  137. FAIL() << "Cast ret to abstract tuple failed.";
  138. }
  139. size_t real = ret->size();
  140. size_t expect = 3;
  141. ASSERT_EQ(real, expect);
  142. }
  143. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  144. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  145. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  146. AbstractBasePtrList eles;
  147. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  148. size_t tuple_size = 6;
  149. for (size_t i = 0; i < tuple_size; i++) {
  150. eles.push_back(tensor);
  151. }
  152. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  153. auto start_index = std::make_shared<AbstractScalar>(1);
  154. auto stop_index = std::make_shared<AbstractScalar>(5);
  155. auto step = std::make_shared<AbstractNone>();
  156. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  157. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  158. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  159. if (ret == nullptr) {
  160. FAIL() << "Cast ret to abstract tuple failed.";
  161. }
  162. size_t real = ret->size();
  163. size_t expect = 4;
  164. ASSERT_EQ(real, expect);
  165. }
  166. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  167. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  168. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  169. AbstractBasePtrList eles;
  170. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  171. size_t tuple_size = 6;
  172. for (size_t i = 0; i < tuple_size; i++) {
  173. eles.push_back(tensor);
  174. }
  175. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  176. auto start_index = std::make_shared<AbstractNone>();
  177. auto stop_index = std::make_shared<AbstractNone>();
  178. auto step = std::make_shared<AbstractScalar>(-1);
  179. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  180. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  181. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  182. if (ret == nullptr) {
  183. FAIL() << "Cast ret to abstract tuple failed.";
  184. }
  185. size_t real = ret->size();
  186. size_t expect = 6;
  187. ASSERT_EQ(real, expect);
  188. }
  189. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  190. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  191. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  192. AbstractBasePtrList eles;
  193. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  194. size_t tuple_size = 6;
  195. for (size_t i = 0; i < tuple_size; i++) {
  196. eles.push_back(tensor);
  197. }
  198. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  199. auto start_index = std::make_shared<AbstractScalar>(-2);
  200. auto stop_index = std::make_shared<AbstractNone>();
  201. auto step = std::make_shared<AbstractScalar>(-1);
  202. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  203. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  204. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  205. if (ret == nullptr) {
  206. FAIL() << "Cast ret to abstract tuple failed.";
  207. }
  208. size_t real = ret->size();
  209. size_t expect = 5;
  210. ASSERT_EQ(real, expect);
  211. }
  212. TEST_F(TestComposite, test_UnpackCall_3args) {
  213. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  214. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  215. auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  216. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  217. AbstractBasePtrList eles;
  218. for (size_t i = 0; i < 6; i++) {
  219. eles.push_back(tensor);
  220. }
  221. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  222. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  223. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  224. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  225. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  226. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  227. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  228. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  229. if (ret == nullptr) {
  230. FAIL() << "Cast ret to abstract tuple failed.";
  231. }
  232. size_t real = ret->size();
  233. size_t expect = 9;
  234. ASSERT_EQ(real, expect);
  235. }
  236. TEST_F(TestComposite, test_UnpackCall_5args) {
  237. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  238. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  239. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  240. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  241. AbstractBasePtrList eles;
  242. for (size_t i = 0; i < 6; i++) {
  243. eles.push_back(tensor);
  244. }
  245. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  246. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  247. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  248. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  249. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  250. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  251. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  252. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  253. if (ret == nullptr) {
  254. FAIL() << "Cast ret to abstract tuple failed.";
  255. }
  256. size_t real = ret->size();
  257. size_t expect = 18;
  258. ASSERT_EQ(real, expect);
  259. }
  260. TEST_F(TestComposite, test_ZipOperation) {
  261. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  262. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  263. AbstractBasePtrList eles;
  264. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  265. size_t tuple_size = 3;
  266. for (size_t i = 0; i < tuple_size; i++) {
  267. eles.push_back(tensor);
  268. }
  269. auto tuple = std::make_shared<AbstractTuple>(eles);
  270. AbstractBasePtrList args_spec_list = {tuple};
  271. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
  272. if (ret == nullptr) {
  273. FAIL() << "Cast ret to abstract tuple failed.";
  274. }
  275. size_t real = ret->size();
  276. size_t expect = 3;
  277. ASSERT_EQ(real, expect);
  278. }
  279. } // namespace mindspore