You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite_test.cc 13 kB


  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <memory>
  #include <string>
  #include "common/common_test.h"
  #include "ir/anf.h"
  #include "ir/value.h"
  #include "frontend/operator/composite/composite.h"
  #include "frontend/operator/ops.h"
  #include "pipeline/jit/static_analysis/prim.h"
  #include "abstract/abstract_function.h"
  #include "debug/trace.h"
  25. namespace mindspore {
  26. using Shape = abstract::Shape;
  27. using AbstractScalar = abstract::AbstractScalar;
  28. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  29. using AbstractSlice = abstract::AbstractSlice;
  30. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  31. using AbstractTuple = abstract::AbstractTuple;
  32. using AbstractTuplePtr = abstract::AbstractTuplePtr;
  33. using AbstractTensor = abstract::AbstractTensor;
  34. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  35. using AbstractNone = abstract::AbstractNone;
  36. using AbstractAttribute = abstract::AbstractAttribute;
  37. using AnalysisEngine = abstract::AnalysisEngine;
  38. using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
  39. class TestComposite : public UT::Common {
  40. public:
  41. virtual void SetUp();
  42. virtual void TearDown();
  43. AnalysisEnginePtr engine_;
  44. };
  45. void TestComposite::SetUp() {
  46. // init resource
  47. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
  48. engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
  49. }
  50. void TestComposite::TearDown() {
  51. // destroy resource
  52. }
  53. class UTCompositeUtils {
  54. public:
  55. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  56. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  57. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  58. }
  59. static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
  60. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(metaFuncGraphPtr));
  63. for (size_t i = 0; i < nparam; i++) {
  64. inputs.push_back(func_graph->add_parameter());
  65. }
  66. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(prim::kPrimReturn));
  69. inputs.push_back(cnode_prim);
  70. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  71. func_graph->set_return(cnode_return);
  72. return func_graph;
  73. }
  74. };
  75. TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  76. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  77. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
  78. AbstractBasePtrList eles;
  79. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  80. size_t tuple_size = 6;
  81. for (size_t i = 0; i < tuple_size; i++) {
  82. eles.push_back(tensor);
  83. }
  84. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  85. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  86. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  87. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};
  88. try {
  89. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  90. FAIL() << "Excepted exception :Args type is wrong";
  91. } catch (std::runtime_error const &err) {
  92. ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
  93. } catch (...) {
  94. FAIL() << "Excepted exception :Args type is wrong";
  95. }
  96. }
  97. TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  98. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  99. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  100. AbstractBasePtrList eles;
  101. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  102. size_t tuple_size = 6;
  103. for (size_t i = 0; i < tuple_size; i++) {
  104. eles.push_back(tensor);
  105. }
  106. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  107. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  108. AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
  109. try {
  110. trace::ClearTraceStack();
  111. engine_->Run(tupleSliceGraphPtr, args_spec_list);
  112. FAIL() << "Excepted exception: Args type is wrong";
  113. } catch (pybind11::type_error const &err) {
  114. ASSERT_TRUE(true);
  115. } catch (std::runtime_error const &err) {
  116. if (std::strstr(err.what(), "TypeError") != nullptr) {
  117. ASSERT_TRUE(true);
  118. } else {
  119. FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
  120. }
  121. } catch (...) {
  122. FAIL() << "Excepted exception: Args type is wrong";
  123. }
  124. }
  125. TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  126. std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
  127. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  128. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  129. AbstractBasePtrList eles;
  130. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  131. size_t tuple_size = 6;
  132. for (size_t i = 0; i < tuple_size; i++) {
  133. eles.push_back(tensor);
  134. }
  135. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  136. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  137. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  138. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  139. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  140. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  141. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  142. if (ret == nullptr) {
  143. FAIL() << "Cast ret to abstract tuple failed.";
  144. }
  145. size_t real = ret->size();
  146. size_t expect = 3;
  147. ASSERT_EQ(real, expect);
  148. }
  149. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  150. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  151. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  152. AbstractBasePtrList eles;
  153. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  154. size_t tuple_size = 6;
  155. for (size_t i = 0; i < tuple_size; i++) {
  156. eles.push_back(tensor);
  157. }
  158. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  159. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  160. auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  161. auto step = std::make_shared<AbstractNone>();
  162. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  163. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  164. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  165. if (ret == nullptr) {
  166. FAIL() << "Cast ret to abstract tuple failed.";
  167. }
  168. size_t real = ret->size();
  169. size_t expect = 4;
  170. ASSERT_EQ(real, expect);
  171. }
  172. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  173. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  174. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  175. AbstractBasePtrList eles;
  176. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  177. size_t tuple_size = 6;
  178. for (size_t i = 0; i < tuple_size; i++) {
  179. eles.push_back(tensor);
  180. }
  181. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  182. auto start_index = std::make_shared<AbstractNone>();
  183. auto stop_index = std::make_shared<AbstractNone>();
  184. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  185. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  186. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  187. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  188. if (ret == nullptr) {
  189. FAIL() << "Cast ret to abstract tuple failed.";
  190. }
  191. size_t real = ret->size();
  192. size_t expect = 6;
  193. ASSERT_EQ(real, expect);
  194. }
  195. TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  196. MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
  197. FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
  198. AbstractBasePtrList eles;
  199. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  200. size_t tuple_size = 6;
  201. for (size_t i = 0; i < tuple_size; i++) {
  202. eles.push_back(tensor);
  203. }
  204. auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  205. auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  206. auto stop_index = std::make_shared<AbstractNone>();
  207. auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  208. auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  209. AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
  210. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  211. if (ret == nullptr) {
  212. FAIL() << "Cast ret to abstract tuple failed.";
  213. }
  214. size_t real = ret->size();
  215. size_t expect = 5;
  216. ASSERT_EQ(real, expect);
  217. }
  218. TEST_F(TestComposite, test_UnpackCall_3args) {
  219. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  220. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
  221. auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  222. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  223. AbstractBasePtrList eles;
  224. for (size_t i = 0; i < 6; i++) {
  225. eles.push_back(tensor);
  226. }
  227. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  228. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  229. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  230. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  231. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  232. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  233. AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  234. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  235. if (ret == nullptr) {
  236. FAIL() << "Cast ret to abstract tuple failed.";
  237. }
  238. size_t real = ret->size();
  239. size_t expect = 9;
  240. ASSERT_EQ(real, expect);
  241. }
  242. TEST_F(TestComposite, test_UnpackCall_5args) {
  243. MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  244. FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
  245. auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
  246. AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  247. AbstractBasePtrList eles;
  248. for (size_t i = 0; i < 6; i++) {
  249. eles.push_back(tensor);
  250. }
  251. AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(eles);
  252. AbstractTensorPtr arr_x = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  253. AbstractTensorPtr arr_y = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  254. AbstractTensorPtr arr_z = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  255. std::vector<AbstractAttribute> tensor_map{{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  256. abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
  257. AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  258. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  259. if (ret == nullptr) {
  260. FAIL() << "Cast ret to abstract tuple failed.";
  261. }
  262. size_t real = ret->size();
  263. size_t expect = 18;
  264. ASSERT_EQ(real, expect);
  265. }
  266. TEST_F(TestComposite, test_ZipOperation) {
  267. MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  268. FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);
  269. AbstractBasePtrList eles;
  270. auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  271. size_t tuple_size = 3;
  272. for (size_t i = 0; i < tuple_size; i++) {
  273. eles.push_back(tensor);
  274. }
  275. auto tuple = std::make_shared<AbstractTuple>(eles);
  276. AbstractBasePtrList args_spec_list = {tuple};
  277. AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
  278. if (ret == nullptr) {
  279. FAIL() << "Cast ret to abstract tuple failed.";
  280. }
  281. size_t real = ret->size();
  282. size_t expect = 3;
  283. ASSERT_EQ(real, expect);
  284. }
  285. } // namespace mindspore