You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

clone_test.cc 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include "common/common_test.h"
  18. #include "common/py_func_graph_fetcher.h"
  19. #include "ir/manager.h"
  20. #include "utils/log_adapter.h"
  21. #include "ir/func_graph_cloner.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "ir/graph_utils.h"
  24. #include "debug/draw.h"
  25. namespace mindspore {
  26. class TestCloner : public UT::Common {
  27. public:
  28. TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
  29. one = NewValueNode(static_cast<int64_t>(1));
  30. two = NewValueNode(static_cast<int64_t>(2));
  31. three = NewValueNode(static_cast<int64_t>(3));
  32. }
  33. FuncGraphPtr GraphForInline() { return nullptr; }
  34. void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params,
  35. FuncGraphPtr target);
  36. public:
  37. UT::PyFuncGraphFetcher getPyFun;
  38. ValueNodePtr one;
  39. ValueNodePtr two;
  40. ValueNodePtr three;
  41. };
  42. void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
  43. const std::vector<AnfNodePtr>& params, FuncGraphPtr target) {
  44. auto g = (*cl)[orig];
  45. ASSERT_TRUE(g != target);
  46. ASSERT_TRUE(g == orig);
  47. auto new_root = (*cl)[orig->output()];
  48. ASSERT_TRUE(new_root != orig->output());
  49. AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
  50. AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));
  51. for (auto& p : params) {
  52. ASSERT_TRUE(new_nodes.contains(p));
  53. }
  54. for (auto& node : orig_nodes) {
  55. if (node->func_graph() == orig) {
  56. ASSERT_TRUE((*cl)[node]);
  57. }
  58. }
  59. ASSERT_TRUE(target->output() == three);
  60. }
  61. TEST_F(TestCloner, test_clone_simple) {
  62. std::string py_code = "test_clone_simple";
  63. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  64. ASSERT_TRUE(g != nullptr);
  65. std::vector<FuncGraphPtr> gs = {g};
  66. Cloner cl(gs, true);
  67. auto g2 = cl[g];
  68. AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return()));
  69. AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return()));
  70. auto common = d1 & d2;
  71. ASSERT_EQ((size_t)0, common.size());
  72. Cloner cl2(gs);
  73. auto g3 = cl2[g];
  74. std::vector<Primitive> results = {Primitive("scalar_add"), Primitive("scalar_mul"), Primitive("return")};
  75. AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
  76. common = d1 & d3;
  77. for (auto& x : common) {
  78. ASSERT_TRUE(x->isa<ValueNode>());
  79. ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) !=
  80. results.end());
  81. }
  82. }
  83. TEST_F(TestCloner, test_clone_closure) {
  84. std::string py_code = "test_clone_closure";
  85. // parse ast to graph
  86. FuncGraphPtr parsed_f = getPyFun(py_code);
  87. FuncGraphIndex idx(parsed_f);
  88. auto g = idx.GetFirstFuncGraph("j");
  89. std::vector<FuncGraphPtr> gs = {g};
  90. Cloner cl(gs, true);
  91. auto g_clone = cl[g];
  92. draw::Draw("test_clone_closure_g_clone.dot", g_clone);
  93. FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch);
  94. std::string name_list = "xy";
  95. for (auto name : name_list) {
  96. ASSERT_EQ(idx.GetFirstNode(std::string(1, name)), idx2.GetFirstNode(std::string(1, name)));
  97. }
  98. ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z"));
  99. ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j"));
  100. }
  101. TEST_F(TestCloner, test_clone_lifting) {
  102. std::string py_code = "test_clone_closure";
  103. // parse ast to graph
  104. FuncGraphPtr parsed_f = getPyFun(py_code);
  105. draw::Draw("test_clone_before_lifting.dot", parsed_f);
  106. auto g_lifting = LiftingClone(parsed_f);
  107. draw::Draw("test_clone_after_lifting.dot", g_lifting);
  108. FuncGraphIndex idx(g_lifting);
  109. auto g = idx.GetFirstFuncGraph("j");
  110. auto params = g_lifting->parameters();
  111. auto child_params = g->parameters();
  112. ASSERT_TRUE(params.size() + 1 == child_params.size());
  113. }
  114. TEST_F(TestCloner, test_clone_scoping) {
  115. std::string py_code = "test_clone_scoping";
  116. // parse ast to graph
  117. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  118. std::vector<FuncGraphPtr> gs = {g};
  119. Cloner cl(gs, true);
  120. auto g2 = cl[g];
  121. FuncGraphIndex idx1(g);
  122. FuncGraphIndex idx2(g2);
  123. std::string name_list = "fgi";
  124. for (auto name : name_list) {
  125. auto result1 = idx1.GetFirstFuncGraph(std::string(1, name));
  126. auto result2 = idx2.GetFirstFuncGraph(std::string(1, name));
  127. ASSERT_FALSE(result1 == result2);
  128. }
  129. name_list = "h";
  130. for (auto name : name_list) {
  131. ASSERT_TRUE(idx1.GetFirstFuncGraph(std::string(1, name)) == idx2.GetFirstFuncGraph(std::string(1, name)));
  132. }
  133. }
  134. TEST_F(TestCloner, test_clone_total) {
  135. std::string py_code = "test_clone_total";
  136. // parse ast to graph
  137. getPyFun.SetDoResolve();
  138. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  139. if (g == nullptr) {
  140. return;
  141. }
  142. FuncGraphIndex idx0(g);
  143. std::vector<FuncGraphPtr> gs = {g};
  144. Cloner cl1(gs, true, true, true);
  145. auto g2 = cl1[g];
  146. FuncGraphIndex idx1(g2);
  147. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total_sub") == idx1.GetFirstFuncGraph("clone_total_sub"));
  148. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx1.GetFirstFuncGraph("clone_total"));
  149. Cloner cl2(gs, true);
  150. FuncGraphIndex idx2(cl2[g]);
  151. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx2.GetFirstFuncGraph("clone_total"));
  152. ASSERT_TRUE(idx0.GetFirstFuncGraph("clone_total_sub") == idx2.GetFirstFuncGraph("clone_total_sub"));
  153. }
  154. } // namespace mindspore