You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clean_test.cc 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "utils/log_adapter.h"
  21. #include "pipeline/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "optimizer/clean.h"
  24. namespace mindspore {
  25. namespace opt {
  26. using mindspore::abstract::AbstractAttribute;
  27. using mindspore::abstract::AbstractClass;
  28. using mindspore::abstract::AbstractError;
  29. using mindspore::abstract::AbstractList;
  30. using mindspore::abstract::AbstractScalar;
  31. using mindspore::abstract::AbstractTensor;
  32. using mindspore::abstract::AbstractTuple;
  33. class TestClean : public UT::Common {
  34. public:
  35. TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {}
  36. virtual void SetUp();
  37. virtual void TearDown();
  38. public:
  39. UT::PyFuncGraphFetcher getPyFun;
  40. FuncGraphPtr me_graph;
  41. };
  42. void TestClean::SetUp() {
  43. // build the func_graph.
  44. me_graph = std::make_shared<FuncGraph>();
  45. me_graph->debug_info()->set_name("next");
  46. // build the nodes
  47. AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next"));
  48. ParameterPtr parameter = std::make_shared<Parameter>(me_graph);
  49. AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(0);
  50. AbstractBasePtr para_list = std::make_shared<AbstractList>(
  51. AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
  52. AbstractBasePtrList para_elem{para_scalar, para_list};
  53. AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem);
  54. parameter->set_abstract(para_tuple);
  55. AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64);
  56. AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64);
  57. AbstractBasePtr app_list = std::make_shared<AbstractList>(
  58. AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
  59. AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list});
  60. AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner});
  61. AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter});
  62. cnode_57->set_abstract(app_tuple);
  63. AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter});
  64. cnode_67->set_abstract(app_tuple);
  65. AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67});
  66. cnode_66->set_abstract(app_float);
  67. AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn);
  68. CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66});
  69. cnode_55->set_abstract(app_tuple);
  70. me_graph->set_output(cnode_66);
  71. me_graph->set_return(cnode_55);
  72. me_graph->add_parameter(parameter);
  73. }
  74. void TestClean::TearDown() {}
  75. TEST_F(TestClean, TestEraseClassGetAttr) {
  76. FuncGraphPtr func_graph;
  77. func_graph = getPyFun("test_erase_class_fn");
  78. ASSERT_TRUE(nullptr != func_graph);
  79. // save the func_graph to manager
  80. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  81. int dataclass_count = 0;
  82. for (auto node : manager->all_nodes()) {
  83. if (IsValueNode<parse::ClassObject>(node)) {
  84. dataclass_count++;
  85. }
  86. if (!node->isa<CNode>()) {
  87. continue;
  88. }
  89. auto input0 = node->cast<CNodePtr>()->input(0);
  90. if (IsValueNode<parse::ClassObject>(input0)) {
  91. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kFloat64)},
  92. {"y", std::make_shared<AbstractScalar>(kFloat64)}};
  93. std::unordered_map<std::string, ValuePtr> methods;
  94. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  95. node->set_abstract(abs_ptr);
  96. }
  97. }
  98. ASSERT_EQ(dataclass_count, 1);
  99. // draw func_graph before erase class
  100. draw::Draw("opt_before_erase_class.dot", func_graph);
  101. SimplifyDataStructures(func_graph, manager);
  102. // draw func_graph after erase class
  103. draw::Draw("opt_after_erase_class.dot", func_graph);
  104. int tuple_getitem_count = 0;
  105. for (auto node : manager->all_nodes()) {
  106. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  107. tuple_getitem_count++;
  108. }
  109. }
  110. ASSERT_EQ(dataclass_count, 1);
  111. ASSERT_EQ(tuple_getitem_count, 2);
  112. }
  113. TEST_F(TestClean, TestEraseClassMakeRecord) {
  114. // build the graph
  115. auto func_graph = std::make_shared<FuncGraph>();
  116. func_graph->debug_info()->set_name("test_make_record");
  117. auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
  118. auto para1 = std::make_shared<Parameter>(func_graph);
  119. auto para2 = std::make_shared<Parameter>(func_graph);
  120. para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  121. para2->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  122. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
  123. {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  124. std::unordered_map<std::string, ValuePtr> methods;
  125. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  126. auto cons_class = NewValueNode(abs_ptr->BuildValue());
  127. cons_class->set_abstract(abs_ptr);
  128. std::vector<AnfNodePtr> inputs{cons_make_record, cons_class, para1, para2};
  129. auto apply22 = func_graph->NewCNode(inputs);
  130. auto cons_return = NewValueNode(prim::kPrimReturn);
  131. auto apply11 = func_graph->NewCNode({cons_return, apply22});
  132. apply11->set_abstract(abs_ptr);
  133. func_graph->set_output(apply22);
  134. func_graph->set_return(apply11);
  135. func_graph->add_parameter(para1);
  136. func_graph->add_parameter(para2);
  137. auto manager = Manage(func_graph);
  138. draw::Draw("opt_erase_class_record_before.dot", func_graph);
  139. SimplifyDataStructures(func_graph, manager);
  140. draw::Draw("opt_erase_class_record_after.dot", func_graph);
  141. }
  142. TEST_F(TestClean, TestEraseClassPartial) {
  143. // build the graph
  144. auto func_graph = std::make_shared<FuncGraph>();
  145. func_graph->debug_info()->set_name("test_partial");
  146. auto cons_partial = NewValueNode(prim::kPrimPartial);
  147. auto para1 = std::make_shared<Parameter>(func_graph);
  148. para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  149. auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
  150. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
  151. {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  152. std::unordered_map<std::string, ValuePtr> methods;
  153. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  154. auto cons_class = NewValueNode(abs_ptr->BuildValue());
  155. cons_class->set_abstract(abs_ptr);
  156. std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1};
  157. auto apply22 = func_graph->NewCNode(inputs);
  158. std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class};
  159. auto apply33 = func_graph->NewCNode(inputs_nopara);
  160. auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33});
  161. auto cons_return = NewValueNode(prim::kPrimReturn);
  162. auto apply00 = func_graph->NewCNode({cons_return, apply11});
  163. apply00->set_abstract(abs_ptr);
  164. func_graph->set_output(apply22);
  165. func_graph->set_return(apply11);
  166. func_graph->add_parameter(para1);
  167. auto manager = Manage(func_graph);
  168. draw::Draw("opt_erase_class_partial_before.dot", func_graph);
  169. SimplifyDataStructures(func_graph, manager);
  170. draw::Draw("opt_erase_class_partial_after.dot", func_graph);
  171. }
  172. TEST_F(TestClean, TestEraseTuple) {
  173. ASSERT_TRUE(nullptr != me_graph);
  174. std::shared_ptr<FuncGraphManager> manager = Manage(me_graph);
  175. draw::Draw("opt_before_erase_tuple.dot", me_graph);
  176. int abstract_tuple_count = 0;
  177. for (auto node : manager->all_nodes()) {
  178. auto dt = node->abstract();
  179. if (dyn_cast<AbstractTuple>(dt) != nullptr) {
  180. abstract_tuple_count++;
  181. }
  182. }
  183. ASSERT_EQ(abstract_tuple_count, 4);
  184. // erase tuple in CNode57 and Parameter
  185. EraseTuple(me_graph, manager);
  186. abstract_tuple_count = 0;
  187. for (auto node : manager->all_nodes()) {
  188. auto dt = node->abstract();
  189. if (dyn_cast<AbstractTuple>(dt) != nullptr) {
  190. abstract_tuple_count++;
  191. }
  192. }
  193. ASSERT_EQ(abstract_tuple_count, 3);
  194. draw::Draw("opt_after_erase_tuple.dot", me_graph);
  195. }
  196. } // namespace opt
  197. } // namespace mindspore