You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clean_test.cc 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "utils/log_adapter.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "frontend/optimizer/clean.h"
  24. namespace mindspore {
  25. namespace opt {
  26. using mindspore::abstract::AbstractAttribute;
  27. using mindspore::abstract::AbstractClass;
  28. using mindspore::abstract::AbstractError;
  29. using mindspore::abstract::AbstractList;
  30. using mindspore::abstract::AbstractScalar;
  31. using mindspore::abstract::AbstractTensor;
  32. using mindspore::abstract::AbstractTuple;
  33. class TestClean : public UT::Common {
  34. public:
  35. TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {}
  36. virtual void SetUp();
  37. virtual void TearDown();
  38. public:
  39. UT::PyFuncGraphFetcher getPyFun;
  40. FuncGraphPtr me_graph;
  41. };
  42. void TestClean::SetUp() {
  43. // build the func_graph.
  44. me_graph = std::make_shared<FuncGraph>();
  45. me_graph->debug_info()->set_name("next");
  46. // build the nodes
  47. AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next"));
  48. ParameterPtr parameter = std::make_shared<Parameter>(me_graph);
  49. AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(static_cast<int64_t>(0));
  50. AbstractBasePtr para_list = std::make_shared<AbstractList>(
  51. AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
  52. AbstractBasePtrList para_elem{para_scalar, para_list};
  53. AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem);
  54. parameter->set_abstract(para_tuple);
  55. AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64);
  56. AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64);
  57. AbstractBasePtr app_list = std::make_shared<AbstractList>(
  58. AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
  59. AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list});
  60. AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner});
  61. AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter});
  62. cnode_57->set_abstract(app_tuple);
  63. AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter});
  64. cnode_67->set_abstract(app_tuple);
  65. AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67});
  66. cnode_66->set_abstract(app_float);
  67. AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn);
  68. CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66});
  69. cnode_55->set_abstract(app_tuple);
  70. me_graph->set_output(cnode_66);
  71. me_graph->set_return(cnode_55);
  72. me_graph->add_parameter(parameter);
  73. }
  74. void TestClean::TearDown() {}
  75. TEST_F(TestClean, TestEraseClassGetAttr) {
  76. FuncGraphPtr func_graph;
  77. func_graph = getPyFun("test_erase_class_fn");
  78. ASSERT_TRUE(nullptr != func_graph);
  79. // save the func_graph to manager
  80. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  81. int dataclass_count = 0;
  82. for (auto node : manager->all_nodes()) {
  83. if (IsValueNode<parse::ClassObject>(node)) {
  84. dataclass_count++;
  85. }
  86. if (!node->isa<CNode>()) {
  87. continue;
  88. }
  89. auto input0 = node->cast<CNodePtr>()->input(0);
  90. if (IsValueNode<parse::ClassObject>(input0)) {
  91. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kFloat64)},
  92. {"y", std::make_shared<AbstractScalar>(kFloat64)}};
  93. mindspore::HashMap<std::string, ValuePtr> methods;
  94. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  95. node->set_abstract(abs_ptr);
  96. }
  97. }
  98. ASSERT_EQ(dataclass_count, 1);
  99. SimplifyDataStructures(func_graph, manager);
  100. int tuple_getitem_count = 0;
  101. for (auto node : manager->all_nodes()) {
  102. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  103. tuple_getitem_count++;
  104. }
  105. }
  106. ASSERT_EQ(dataclass_count, 1);
  107. ASSERT_EQ(tuple_getitem_count, 2);
  108. }
  109. TEST_F(TestClean, TestEraseClassMakeRecord) {
  110. // build the graph
  111. auto func_graph = std::make_shared<FuncGraph>();
  112. func_graph->debug_info()->set_name("test_make_record");
  113. auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
  114. auto para1 = std::make_shared<Parameter>(func_graph);
  115. auto para2 = std::make_shared<Parameter>(func_graph);
  116. para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  117. para2->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  118. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
  119. {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  120. mindspore::HashMap<std::string, ValuePtr> methods;
  121. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  122. auto cons_class = NewValueNode(abs_ptr->BuildValue());
  123. cons_class->set_abstract(abs_ptr);
  124. std::vector<AnfNodePtr> inputs{cons_make_record, cons_class, para1, para2};
  125. auto apply22 = func_graph->NewCNode(inputs);
  126. auto cons_return = NewValueNode(prim::kPrimReturn);
  127. auto apply11 = func_graph->NewCNode({cons_return, apply22});
  128. apply11->set_abstract(abs_ptr);
  129. func_graph->set_output(apply22);
  130. func_graph->set_return(apply11);
  131. func_graph->add_parameter(para1);
  132. func_graph->add_parameter(para2);
  133. auto manager = Manage(func_graph);
  134. SimplifyDataStructures(func_graph, manager);
  135. }
  136. TEST_F(TestClean, TestEraseClassPartial) {
  137. // build the graph
  138. auto func_graph = std::make_shared<FuncGraph>();
  139. func_graph->debug_info()->set_name("test_partial");
  140. auto cons_partial = NewValueNode(prim::kPrimPartial);
  141. auto para1 = std::make_shared<Parameter>(func_graph);
  142. para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  143. auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
  144. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
  145. {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  146. mindspore::HashMap<std::string, ValuePtr> methods;
  147. AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  148. auto cons_class = NewValueNode(abs_ptr->BuildValue());
  149. cons_class->set_abstract(abs_ptr);
  150. std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1};
  151. auto apply22 = func_graph->NewCNode(inputs);
  152. std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class};
  153. auto apply33 = func_graph->NewCNode(inputs_nopara);
  154. auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33});
  155. auto cons_return = NewValueNode(prim::kPrimReturn);
  156. auto apply00 = func_graph->NewCNode({cons_return, apply11});
  157. apply00->set_abstract(abs_ptr);
  158. func_graph->set_output(apply22);
  159. func_graph->set_return(apply11);
  160. func_graph->add_parameter(para1);
  161. auto manager = Manage(func_graph);
  162. SimplifyDataStructures(func_graph, manager);
  163. }
  164. } // namespace opt
  165. } // namespace mindspore