You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager_test.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "common/py_func_graph_fetcher.h"
  18. #include "ir/dtype.h"
  19. #include "ir/manager.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "pipeline/parse/parse.h"
  22. #include "operator/ops.h"
  23. #include "utils/log_adapter.h"
  24. #include "debug/draw.h"
  25. #include "debug/label.h"
  26. #include "./common.h"
  27. namespace mindspore {
  namespace {
  // Splits `str` on every non-overlapping occurrence of `pattern`, scanning left
  // to right, and returns the pieces in order.
  //
  // Empty pieces are kept: SplitString("a;;b;", ";") yields {"a", "", "b", ""}
  // and an empty `str` yields {""}. An empty `pattern` cannot separate anything,
  // so the whole input is returned as the single piece (it must not be used as a
  // search step of length zero, which would never advance).
  //
  // The scan runs directly over `str` rather than over `str + pattern`: a
  // sentinel appended to the input could combine with the input's trailing
  // characters into a bogus match (e.g. "a" + "aa" contains "aa" at offset 0).
  std::vector<std::string> SplitString(std::string str, std::string pattern) {
    std::vector<std::string> result;
    if (pattern.empty()) {
      result.push_back(str);
      return result;
    }
    std::string::size_type start = 0;
    for (;;) {
      const std::string::size_type pos = str.find(pattern, start);
      if (pos == std::string::npos) {
        // Whatever follows the last separator (possibly "") is the final piece.
        result.push_back(str.substr(start));
        return result;
      }
      result.push_back(str.substr(start, pos - start));
      start = pos + pattern.size();
    }
  }
  }  // namespace
  45. using std::dynamic_pointer_cast;
  46. using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
  47. using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
  48. class NestingSpecs;
  49. class Stage {
  50. public:
  51. explicit Stage(std::vector<std::string> specs) {
  52. for (auto arg : specs) {
  53. auto spec = SplitString(arg, "=");
  54. if (spec.size() <= 1) {
  55. continue;
  56. }
  57. std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
  58. specs_[ToFullString(spec[0])] = nesting;
  59. }
  60. }
  61. ~Stage() {}
  62. std::map<std::string, std::string>& subs() { return subs_; }
  63. void set_subs(const std::map<std::string, std::string>& subs) { subs_ = subs; }
  64. private:
  65. std::string ToFullString(std::string s) {
  66. if (s.find("fv") != std::string::npos) {
  67. s = s.replace(s.find("fv"), 2, "free_variable");
  68. }
  69. if (s.find("deps") != std::string::npos) {
  70. s = s.replace(s.find("deps"), 4, "dependencies");
  71. }
  72. return s;
  73. }
  74. std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
  75. std::map<std::string, std::string> subs_;
  76. };
  77. class NestingSpecs {
  78. public:
  79. NestingSpecs(Stage* stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
  80. ~NestingSpecs() {}
  81. std::string Name(Any node) {
  82. std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
  83. if (stage_->subs().find(name) != stage_->subs().end()) {
  84. return stage_->subs()[name];
  85. }
  86. return name;
  87. }
  88. void Check(std::shared_ptr<FuncGraphAnalysis> results) {
  89. if (expected_.empty() && expected_recursive_.empty()) {
  90. return;
  91. }
  92. auto parent = dynamic_pointer_cast<ParentComputer>(results);
  93. if (parent != nullptr) {
  94. CheckParent(parent);
  95. return;
  96. }
  97. auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
  98. if (recursive != nullptr) {
  99. CheckRecursive(recursive);
  100. return;
  101. }
  102. auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
  103. if (counter_g != nullptr) {
  104. CheckGraphCounter(counter_g);
  105. return;
  106. }
  107. auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
  108. if (counter_p != nullptr) {
  109. CheckAnfNodeCounter(counter_p);
  110. return;
  111. }
  112. }
  113. private:
  114. void ParseSpecs(std::string specs) {
  115. if (specs.empty()) {
  116. return;
  117. }
  118. std::vector<std::string> str_list = SplitString(specs, ";");
  119. for (auto spec : str_list) {
  120. spec.erase(0, spec.find_first_not_of(" "));
  121. spec.erase(spec.find_last_not_of(" ") + 1);
  122. if (spec.empty()) {
  123. continue;
  124. }
  125. if (spec.find("->") != std::string::npos) {
  126. auto substr = SplitString(spec, "->");
  127. ASSERT_GT(substr.size(), 1);
  128. auto key = substr[0];
  129. auto value = substr[1];
  130. if (!value.empty()) {
  131. expected_[key] = {value};
  132. }
  133. } else if (spec.find(":") != std::string::npos) {
  134. auto substr = SplitString(spec, ":");
  135. ASSERT_GT(substr.size(), 1);
  136. auto key = substr[0];
  137. auto values = SplitString(substr[1], ",");
  138. std::set<std::string> values_set(values.begin(), values.end());
  139. if (!values_set.empty()) {
  140. expected_[key] = values_set;
  141. }
  142. } else {
  143. expected_recursive_[spec] = true;
  144. }
  145. }
  146. }
  147. void CheckParent(std::shared_ptr<ParentComputer> results) {
  148. std::map<std::string, std::set<std::string>> clean_results;
  149. for (auto& iter : results->parent_analysis()) {
  150. auto key = iter.first;
  151. auto value = iter.second;
  152. if (key == nullptr) {
  153. continue;
  154. }
  155. std::string k = Name(key);
  156. std::set<std::string> v;
  157. if (value != nullptr && !Name(value).empty()) {
  158. v.insert(Name(value));
  159. }
  160. if (!v.empty()) {
  161. clean_results[k] = v;
  162. }
  163. }
  164. ASSERT_EQ(clean_results, expected_);
  165. }
  166. // Add CheckNesting function
  167. void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
  168. std::map<std::string, std::set<std::string>> clean_results;
  169. for (auto& iter : results->count_nodes_map()) {
  170. auto key = iter.first;
  171. auto value = iter.second;
  172. if (key == nullptr) {
  173. continue;
  174. }
  175. std::string k = Name(key);
  176. std::set<std::string> v;
  177. for (auto& node : value) {
  178. auto fg = node.first;
  179. if (!Name(fg).empty()) {
  180. v.insert(Name(fg));
  181. }
  182. }
  183. if (!v.empty()) {
  184. clean_results[k] = v;
  185. }
  186. }
  187. ASSERT_EQ(clean_results, expected_);
  188. }
  189. void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
  190. std::map<std::string, std::set<std::string>> clean_results;
  191. for (auto& iter : results->count_func_graphs_map()) {
  192. auto key = iter.first;
  193. auto value = iter.second;
  194. if (key == nullptr) {
  195. continue;
  196. }
  197. std::string k = Name(key);
  198. std::set<std::string> v;
  199. for (auto& node : value) {
  200. auto fg = node.first;
  201. if (!Name(fg).empty()) {
  202. v.insert(Name(fg));
  203. }
  204. }
  205. if (!v.empty()) {
  206. clean_results[k] = v;
  207. }
  208. }
  209. ASSERT_EQ(clean_results, expected_);
  210. }
  211. void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
  212. std::map<std::string, bool> clean_results;
  213. for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
  214. auto key = iter->first;
  215. auto value = iter->second;
  216. if (key == nullptr) {
  217. continue;
  218. }
  219. std::string k = Name(key);
  220. clean_results[k] = value;
  221. }
  222. ASSERT_EQ(clean_results, expected_recursive_);
  223. }
  224. private:
  225. Stage* stage_;
  226. std::map<std::string, std::set<std::string>> expected_;
  227. std::map<std::string, bool> expected_recursive_;
  228. };
  229. bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
  230. for (auto node : manager->all_nodes()) {
  231. if (node->isa<CNode>()) {
  232. auto& inputs = node->cast<CNodePtr>()->inputs();
  233. for (size_t i = 0; i < inputs.size(); ++i) {
  234. auto inp = inputs[i];
  235. if (!manager->all_nodes().contains(inp)) {
  236. return false;
  237. }
  238. if (manager->node_users().find(inp) != manager->node_users().end()) {
  239. auto users = manager->node_users()[inp];
  240. if (!users.contains(make_pair(node, i))) {
  241. return false;
  242. }
  243. }
  244. }
  245. }
  246. if (manager->node_users().find(node) != manager->node_users().end()) {
  247. auto users = manager->node_users()[node];
  248. for (auto iter = users.begin(); iter != users.end(); ++iter) {
  249. auto node2 = iter->first;
  250. auto key = iter->second;
  251. if (!manager->all_nodes().contains(node2)) {
  252. return false;
  253. }
  254. if (node2->cast<CNodePtr>()->input(key) != node) {
  255. return false;
  256. }
  257. }
  258. }
  259. }
  260. return true;
  261. }
  262. class TestManager : public UT::Common {
  263. public:
  264. TestManager() : getPyFun("gtest_input.ir.manager_test") {}
  265. void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
  266. public:
  267. std::vector<PrimitivePtr> swaps;
  268. UT::PyFuncGraphFetcher getPyFun;
  269. };
  270. FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  271. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  272. ParameterPtr x = func_graph->add_parameter();
  273. ParameterPtr y = func_graph->add_parameter();
  274. std::vector<AnfNodePtr> inputs;
  275. inputs.push_back(NewValueNode(prim));
  276. inputs.push_back(x);
  277. inputs.push_back(y);
  278. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  279. inputs.clear();
  280. inputs.push_back(NewValueNode(prim::kPrimReturn));
  281. inputs.push_back(cnode_add);
  282. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  283. func_graph->set_return(cnode_return);
  284. return func_graph;
  285. }
  286. std::vector<FuncGraphPtr> MakeNestedGraph() {
  287. /*
  288. *def f(x):
  289. * def g():
  290. * return x
  291. * return g
  292. */
  293. FuncGraphPtr f = std::make_shared<FuncGraph>();
  294. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  295. ParameterPtr x = f->add_parameter();
  296. std::vector<AnfNodePtr> inputs;
  297. inputs.push_back(NewValueNode(fg));
  298. inputs.push_back(NewValueNode(prim::kPrimReturn));
  299. CNodePtr cnode_f = f->NewCNode(inputs);
  300. f->set_return(cnode_f);
  301. inputs.clear();
  302. inputs.push_back(NewValueNode(prim::kPrimReturn));
  303. inputs.push_back(x);
  304. CNodePtr cnode_g = fg->NewCNode(inputs);
  305. fg->set_return(cnode_g);
  306. std::vector<FuncGraphPtr> result = {f, fg};
  307. return result;
  308. }
  309. std::vector<FuncGraphPtr> MakeNestedGraph2() {
  310. /* build a closure func_graph */
  311. /*
  312. *def foo(x, y):
  313. * def bar(x1):
  314. * return x1 + y
  315. * return bar(x)
  316. */
  317. FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
  318. ParameterPtr x = graph_foo->add_parameter();
  319. ParameterPtr y = graph_foo->add_parameter();
  320. std::vector<AnfNodePtr> inputs;
  321. // build func_graph bar
  322. FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
  323. ParameterPtr x1 = graph_bar->add_parameter();
  324. inputs.clear();
  325. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  326. inputs.push_back(x1);
  327. inputs.push_back(y);
  328. CNodePtr cnode_add = graph_bar->NewCNode(inputs);
  329. inputs.clear();
  330. inputs.push_back(NewValueNode(prim::kPrimReturn));
  331. inputs.push_back(cnode_add);
  332. CNodePtr cnode_return = graph_bar->NewCNode(inputs);
  333. graph_bar->set_return(cnode_return);
  334. // build func_graph foo
  335. inputs.clear();
  336. inputs.push_back(NewValueNode(graph_bar));
  337. inputs.push_back(x);
  338. CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
  339. inputs.clear();
  340. inputs.push_back(NewValueNode(prim::kPrimReturn));
  341. inputs.push_back(cnode_graph_bar);
  342. cnode_return = graph_foo->NewCNode(inputs);
  343. graph_foo->set_return(cnode_return);
  344. std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
  345. return result;
  346. }
  347. // Add TestManager::CheckManager function to checkout the result
  348. void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
  349. auto size = mng->func_graphs().size();
  350. ASSERT_EQ(size, mng->free_variables_total().size());
  351. }
  352. TEST_F(TestManager, test_scalar_add_manual) {
  353. auto prim_scalar_add = prim::kPrimScalarAdd;
  354. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  355. auto mng = Manage(func_graph);
  356. }
  357. TEST_F(TestManager, test_scalar_replace) {
  358. auto prim_scalar_add = prim::kPrimScalarAdd;
  359. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  360. ParameterPtr x = func_graph->add_parameter();
  361. ParameterPtr y = func_graph->add_parameter();
  362. std::vector<AnfNodePtr> inputs;
  363. inputs.push_back(NewValueNode(prim_scalar_add));
  364. inputs.push_back(x);
  365. inputs.push_back(y);
  366. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  367. inputs.clear();
  368. inputs.push_back(NewValueNode(prim::kPrimReturn));
  369. inputs.push_back(cnode_add);
  370. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  371. func_graph->set_return(cnode_return);
  372. auto mng = Manage(func_graph);
  373. std::cout << "start " << x->ToString() << std::endl;
  374. mng->Replace(cnode_add, x);
  375. }
  376. TEST_F(TestManager, test_nested_manual) {
  377. auto graphs = MakeNestedGraph();
  378. auto f = graphs[0];
  379. auto g = graphs[1];
  380. auto mng = Manage(f);
  381. ASSERT_EQ(6, mng->all_nodes().size());
  382. ASSERT_EQ(2, mng->func_graphs().size());
  383. ASSERT_EQ(4, mng->node_users().size());
  384. ASSERT_EQ(1, mng->roots().size());
  385. CheckAnalysisSize(mng);
  386. ASSERT_EQ(2, f->nodes().size());
  387. ASSERT_EQ(1, g->nodes().size());
  388. auto users = mng->node_users();
  389. for (auto& iter : users) {
  390. ASSERT_EQ(1, iter.second.size());
  391. }
  392. ASSERT_EQ(1, f->func_graphs_used().size());
  393. ASSERT_EQ(0, g->func_graphs_used().size());
  394. ASSERT_EQ(0, f->free_variables().size());
  395. ASSERT_EQ(1, g->free_variables().size());
  396. auto fv_total = mng->free_variables_total();
  397. ASSERT_EQ(0, fv_total[f].size());
  398. ASSERT_EQ(1, fv_total[g].size());
  399. ASSERT_EQ(0, f->func_graph_cnodes_index().size());
  400. ASSERT_EQ(1, g->func_graph_cnodes_index().size());
  401. }
  402. TEST_F(TestManager, test_deep_nested2_manual) {
  403. // create parser
  404. FuncGraphPtr func_graph = getPyFun("test_custom");
  405. return;
  406. // parse ast to func graph
  407. FuncGraphPtr gfn = BasicClone(func_graph);
  408. if (gfn == nullptr) {
  409. return;
  410. }
  411. auto mng = Manage(gfn);
  412. ASSERT_EQ(3, mng->func_graphs().size());
  413. ASSERT_EQ(1, mng->roots().size());
  414. ASSERT_EQ(4, gfn->nodes().size());
  415. ASSERT_EQ(20, mng->all_nodes().size());
  416. ASSERT_EQ(25, mng->node_users().size());
  417. CheckAnalysisSize(mng);
  418. }
  419. TEST_F(TestManager, test_deep_nested_manual) {
  420. FuncGraphPtr f = std::make_shared<FuncGraph>();
  421. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  422. FuncGraphPtr h = std::make_shared<FuncGraph>();
  423. ParameterPtr x = f->add_parameter();
  424. ParameterPtr y = f->add_parameter();
  425. ParameterPtr z = f->add_parameter();
  426. std::vector<AnfNodePtr> inputs;
  427. inputs.push_back(NewValueNode(fg));
  428. inputs.push_back(x);
  429. inputs.push_back(y);
  430. CNodePtr cnode_1 = f->NewCNode(inputs);
  431. inputs.clear();
  432. inputs.push_back(cnode_1);
  433. inputs.push_back(NewValueNode(prim::kPrimReturn));
  434. CNodePtr cnode_0 = f->NewCNode(inputs);
  435. f->set_return(cnode_0);
  436. ParameterPtr x1 = fg->add_parameter();
  437. ParameterPtr y1 = fg->add_parameter();
  438. inputs.clear();
  439. inputs.push_back(NewValueNode(h));
  440. inputs.push_back(x1);
  441. CNodePtr cnode_3 = fg->NewCNode(inputs);
  442. inputs.clear();
  443. inputs.push_back(cnode_3);
  444. inputs.push_back(NewValueNode(prim::kPrimReturn));
  445. CNodePtr cnode_2 = fg->NewCNode(inputs);
  446. fg->set_return(cnode_2);
  447. ParameterPtr x2 = h->add_parameter();
  448. inputs.clear();
  449. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  450. inputs.push_back(x2);
  451. inputs.push_back(y1);
  452. CNodePtr cnode_6 = h->NewCNode(inputs);
  453. inputs.clear();
  454. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  455. inputs.push_back(z);
  456. inputs.push_back(cnode_6);
  457. CNodePtr cnode_5 = h->NewCNode(inputs);
  458. inputs.clear();
  459. inputs.push_back(cnode_5);
  460. inputs.push_back(NewValueNode(prim::kPrimReturn));
  461. CNodePtr cnode_4 = h->NewCNode(inputs);
  462. h->set_return(cnode_4);
  463. auto mng = Manage(f);
  464. ASSERT_EQ(3, mng->func_graphs().size());
  465. ASSERT_EQ(1, mng->roots().size());
  466. ASSERT_EQ(20, mng->all_nodes().size());
  467. CheckAnalysisSize(mng);
  468. }
  469. TEST_F(TestManager, test_parent1_manual) {
  470. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  471. Parameter param(fg);
  472. std::vector<AnfNodePtr> params;
  473. CNodePtr app = std::make_shared<CNode>(params, fg);
  474. fg->set_return(app);
  475. fg->set_parameters(params);
  476. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  477. manager->AddFuncGraph(fg, true);
  478. FuncGraphPtr p = fg->parent();
  479. assert(p == nullptr);
  480. }
  481. TEST_F(TestManager, test_parent_manual) {
  482. auto prim_scalar_add = prim::kPrimScalarAdd;
  483. FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
  484. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  485. manager->AddFuncGraph(fg);
  486. FuncGraphPtr p = fg->parent();
  487. assert(p == nullptr);
  488. }
  489. TEST_F(TestManager, test_flat) {
  490. std::vector<std::shared_ptr<Stage>> stages;
  491. std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  492. std::map<std::string, int> size_list;
  493. size_list["nodes"] = 2;
  494. }
  495. TEST_F(TestManager, test_nested) {
  496. std::vector<std::shared_ptr<Stage>> stages;
  497. std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  498. std::map<std::string, int> size_list;
  499. return;
  500. }
  501. TEST_F(TestManager, test_calls) {
  502. std::vector<std::shared_ptr<Stage>> stages;
  503. std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
  504. "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  505. std::map<std::string, int> size_list;
  506. return;
  507. }
  508. TEST_F(TestManager, test_unused_param) {
  509. std::vector<std::shared_ptr<Stage>> stages;
  510. std::vector<std::string> specs = {"nodes=X:x,y"};
  511. std::map<std::string, int> size_list;
  512. }
  513. TEST_F(TestManager, test_cannot_replace_return) {
  514. FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  515. ASSERT_NE(fg, nullptr);
  516. auto mng = Manage(fg);
  517. ASSERT_EQ(fg->manager(), mng);
  518. ASSERT_NE(mng, nullptr);
  519. ASSERT_GT(fg->parameters().size(), 0);
  520. ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
  521. }
  522. TEST_F(TestManager, test_weak_manager) {
  523. FuncGraphPtr fg = getPyFun("ir_get_fn");
  524. auto mng1 = MakeManager({fg}, false);
  525. ASSERT_EQ(fg->manager(), nullptr);
  526. auto mng2 = MakeManager({fg}, true);
  527. ASSERT_EQ(fg->manager(), mng2);
  528. auto mng3 = MakeManager({fg}, false);
  529. ASSERT_EQ(fg->manager(), mng2);
  530. }
  531. TEST_F(TestManager, test_drop_root) {
  532. FuncGraphPtr fg = getPyFun("ir_get_fn");
  533. auto mng = Manage(fg);
  534. const auto &fgs = mng->func_graphs();
  535. ASSERT_TRUE(fgs.contains(fg));
  536. FuncGraphSet s;
  537. s.add(fg);
  538. mng->MaybeDropFuncGraphs(s);
  539. ASSERT_TRUE(fgs.contains(fg));
  540. }
  541. TEST_F(TestManager, test_keep_roots) {
  542. FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  543. FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  544. auto mng = Manage(fg1);
  545. ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
  546. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  547. mng->AddFuncGraph(fg2);
  548. ASSERT_EQ(mng->func_graphs().size(), 2);
  549. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  550. mng->KeepRoots();
  551. ASSERT_EQ(mng->func_graphs().size(), 1);
  552. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  553. mng->KeepRoots({fg2});
  554. ASSERT_EQ(mng->func_graphs().size(), 1);
  555. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  556. }
  557. TEST_F(TestManager, test_keep_roots_recursion) {
  558. return;
  559. FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
  560. ASSERT_NE(fg, nullptr);
  561. auto mng = Manage(fg);
  562. parse::ResolveAll(mng);
  563. ASSERT_NE(mng, nullptr);
  564. ASSERT_EQ(mng->func_graphs().size(), 4);
  565. ASSERT_GT(fg->parameters().size(), 0);
  566. mng->Replace(fg->output(), fg->parameters()[0]);
  567. ASSERT_EQ(mng->func_graphs().size(), 3);
  568. mng->KeepRoots();
  569. ASSERT_EQ(mng->func_graphs().size(), 1);
  570. }
  571. } // namespace mindspore