You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clone_test.cc 8.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include "common/common_test.h"
  18. #include "common/py_func_graph_fetcher.h"
  19. #include "ir/manager.h"
  20. #include "utils/log_adapter.h"
  21. #include "ir/func_graph_cloner.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "ir/graph_utils.h"
  24. #include "debug/draw.h"
  25. #include "base/core_ops.h"
  26. namespace mindspore {
  27. class FuncGraphIndex {
  28. public:
  29. explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
  30. const IncludeFunc &include = AlwaysInclude);
  31. FuncGraphIndex(const FuncGraphIndex &) = delete;
  32. FuncGraphIndex &operator=(const FuncGraphIndex &) = delete;
  33. virtual ~FuncGraphIndex() {}
  34. std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key);
  35. std::set<AnfNodePtr> GetNodes(const std::string &key);
  36. FuncGraphPtr GetFirstFuncGraph(const std::string &key);
  37. AnfNodePtr GetFirstNode(const std::string &key);
  38. private:
  39. void Acquire(const FuncGraphPtr &key);
  40. void Acquire(const AnfNodePtr &key);
  41. std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_;
  42. std::map<std::string, std::set<AnfNodePtr>> index_node_;
  43. };
  44. FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) {
  45. MS_EXCEPTION_IF_NULL(fg);
  46. Acquire(fg);
  47. auto vec = search(fg->get_return(), include);
  48. for (auto &node : vec) {
  49. MS_EXCEPTION_IF_NULL(node);
  50. Acquire(node);
  51. if (node->func_graph() != nullptr) {
  52. Acquire(node->func_graph());
  53. }
  54. }
  55. }
  56. std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) {
  57. std::set<FuncGraphPtr> func_graphs;
  58. if (index_func_graph_.find(key) != index_func_graph_.end()) {
  59. func_graphs = index_func_graph_[key];
  60. }
  61. return func_graphs;
  62. }
  63. std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) {
  64. if (index_node_.find(key) != index_node_.end()) {
  65. return index_node_[key];
  66. }
  67. return std::set<AnfNodePtr>();
  68. }
  69. FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) {
  70. if (GetFuncGraphs(key).empty()) {
  71. return nullptr;
  72. }
  73. auto fg = *GetFuncGraphs(key).begin();
  74. return fg;
  75. }
  76. AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) {
  77. if (GetNodes(key).empty()) {
  78. return nullptr;
  79. }
  80. auto node = *GetNodes(key).begin();
  81. return node;
  82. }
  83. void FuncGraphIndex::Acquire(const FuncGraphPtr &key) {
  84. std::string name = label_manage::Label(key->debug_info());
  85. if (!name.empty()) {
  86. (void)index_func_graph_[name].insert(key);
  87. }
  88. }
  89. void FuncGraphIndex::Acquire(const AnfNodePtr &key) {
  90. std::string name = label_manage::Label(key->debug_info());
  91. if (!name.empty()) {
  92. (void)index_node_[name].insert(key);
  93. }
  94. }
  95. class TestCloner : public UT::Common {
  96. public:
  97. TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
  98. one = NewValueNode(static_cast<int64_t>(1));
  99. two = NewValueNode(static_cast<int64_t>(2));
  100. three = NewValueNode(static_cast<int64_t>(3));
  101. }
  102. FuncGraphPtr GraphForInline() { return nullptr; }
  103. void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr> &params,
  104. FuncGraphPtr target);
  105. public:
  106. UT::PyFuncGraphFetcher getPyFun;
  107. ValueNodePtr one;
  108. ValueNodePtr two;
  109. ValueNodePtr three;
  110. };
  111. void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
  112. const std::vector<AnfNodePtr> &params, FuncGraphPtr target) {
  113. auto g = (*cl)[orig];
  114. ASSERT_TRUE(g != target);
  115. ASSERT_TRUE(g == orig);
  116. auto new_root = (*cl)[orig->output()];
  117. ASSERT_TRUE(new_root != orig->output());
  118. AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
  119. AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));
  120. for (auto &p : params) {
  121. ASSERT_TRUE(new_nodes.contains(p));
  122. }
  123. for (auto &node : orig_nodes) {
  124. if (node->func_graph() == orig) {
  125. ASSERT_TRUE((*cl)[node]);
  126. }
  127. }
  128. ASSERT_TRUE(target->output() == three);
  129. }
  130. TEST_F(TestCloner, test_clone_simple) {
  131. std::string py_code = "test_clone_simple";
  132. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  133. ASSERT_TRUE(g != nullptr);
  134. std::vector<FuncGraphPtr> gs = {g};
  135. Cloner cl(gs, true);
  136. auto g2 = cl[g];
  137. AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return()));
  138. AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return()));
  139. auto common = d1 & d2;
  140. ASSERT_EQ((size_t)0, common.size());
  141. Cloner cl2(gs);
  142. auto g3 = cl2[g];
  143. std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")};
  144. AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
  145. common = d1 & d3;
  146. for (auto &x : common) {
  147. ASSERT_TRUE(x->isa<ValueNode>());
  148. ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) !=
  149. results.end());
  150. }
  151. }
  152. TEST_F(TestCloner, test_clone_closure) {
  153. std::string py_code = "test_clone_closure";
  154. // parse ast to graph
  155. FuncGraphPtr parsed_f = getPyFun(py_code);
  156. FuncGraphIndex idx(parsed_f);
  157. auto g = idx.GetFirstFuncGraph("j");
  158. std::vector<FuncGraphPtr> gs = {g};
  159. Cloner cl(gs, true);
  160. auto g_clone = cl[g];
  161. FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch);
  162. std::string name_list = "xy";
  163. for (auto name : name_list) {
  164. ASSERT_EQ(idx.GetFirstNode(std::string(1, name)), idx2.GetFirstNode(std::string(1, name)));
  165. }
  166. ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z"));
  167. ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j"));
  168. }
  169. TEST_F(TestCloner, test_clone_lifting) {
  170. std::string py_code = "test_clone_closure";
  171. // parse ast to graph
  172. FuncGraphPtr parsed_f = getPyFun(py_code);
  173. auto g_lifting = LiftingClone(parsed_f);
  174. FuncGraphIndex idx(g_lifting);
  175. auto g = idx.GetFirstFuncGraph("j");
  176. auto params = g_lifting->parameters();
  177. auto child_params = g->parameters();
  178. ASSERT_TRUE(params.size() + 1 == child_params.size());
  179. }
  180. TEST_F(TestCloner, test_clone_scoping) {
  181. std::string py_code = "test_clone_scoping";
  182. // parse ast to graph
  183. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  184. std::vector<FuncGraphPtr> gs = {g};
  185. Cloner cl(gs, true);
  186. auto g2 = cl[g];
  187. FuncGraphIndex idx1(g);
  188. FuncGraphIndex idx2(g2);
  189. std::string name_list = "fgi";
  190. for (auto name : name_list) {
  191. auto result1 = idx1.GetFirstFuncGraph(std::string(1, name));
  192. auto result2 = idx2.GetFirstFuncGraph(std::string(1, name));
  193. ASSERT_FALSE(result1 == result2);
  194. }
  195. name_list = "h";
  196. for (auto name : name_list) {
  197. ASSERT_TRUE(idx1.GetFirstFuncGraph(std::string(1, name)) == idx2.GetFirstFuncGraph(std::string(1, name)));
  198. }
  199. }
  200. TEST_F(TestCloner, test_clone_total) {
  201. std::string py_code = "test_clone_total";
  202. // parse ast to graph
  203. getPyFun.SetDoResolve();
  204. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  205. if (g == nullptr) {
  206. return;
  207. }
  208. FuncGraphIndex idx0(g);
  209. std::vector<FuncGraphPtr> gs = {g};
  210. Cloner cl1(gs, true, true, true);
  211. auto g2 = cl1[g];
  212. FuncGraphIndex idx1(g2);
  213. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total_sub") == idx1.GetFirstFuncGraph("clone_total_sub"));
  214. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx1.GetFirstFuncGraph("clone_total"));
  215. Cloner cl2(gs, true);
  216. FuncGraphIndex idx2(cl2[g]);
  217. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx2.GetFirstFuncGraph("clone_total"));
  218. ASSERT_TRUE(idx0.GetFirstFuncGraph("clone_total_sub") == idx2.GetFirstFuncGraph("clone_total_sub"));
  219. }
  220. } // namespace mindspore