You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager_test.cc 21 kB


  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "common/py_func_graph_fetcher.h"
  18. #include "ir/dtype.h"
  19. #include "ir/manager.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "pipeline/parse/parse.h"
  22. #include "operator/ops.h"
  23. #include "utils/log_adapter.h"
  24. #include "debug/draw.h"
  25. #include "debug/label.h"
  26. #include "./common.h"
  27. namespace mindspore {
  28. namespace {
  /// Split `str` on every occurrence of `pattern`.
  ///
  /// Empty fields are kept: "a;;b" -> {"a", "", "b"}, "a;" -> {"a", ""}, "" -> {""}.
  /// Matching is left-to-right and non-overlapping. An empty `pattern` cannot delimit anything,
  /// so it yields {str} instead of looping forever.
  ///
  /// @param str      text to split
  /// @param pattern  delimiter; may be longer than one character (e.g. "->")
  /// @return the fields between delimiters, in order
  std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
    std::vector<std::string> result;
    if (pattern.empty()) {
      result.push_back(str);
      return result;
    }
    std::string::size_type start = 0;
    for (;;) {
      const std::string::size_type pos = str.find(pattern, start);
      if (pos == std::string::npos) {
        // Trailing field after the last delimiter (possibly empty).
        result.push_back(str.substr(start));
        return result;
      }
      result.push_back(str.substr(start, pos - start));
      start = pos + pattern.size();
    }
  }
  44. } // namespace
  45. using std::dynamic_pointer_cast;
  46. using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
  47. using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
  48. class NestingSpecs;
  49. class Stage {
  50. public:
  51. explicit Stage(std::vector<std::string> specs) {
  52. for (auto arg : specs) {
  53. auto spec = SplitString(arg, "=");
  54. if (spec.size() <= 1) {
  55. continue;
  56. }
  57. std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
  58. specs_[ToFullString(spec[0])] = nesting;
  59. }
  60. }
  61. ~Stage() {}
  62. std::map<std::string, std::string>& subs() { return subs_; }
  63. void set_subs(const std::map<std::string, std::string>& subs) { subs_ = subs; }
  64. private:
  65. std::string ToFullString(std::string s) {
  66. if (s.find("fv") != std::string::npos) {
  67. s = s.replace(s.find("fv"), 2, "free_variable");
  68. }
  69. if (s.find("deps") != std::string::npos) {
  70. s = s.replace(s.find("deps"), 4, "dependencies");
  71. }
  72. return s;
  73. }
  74. std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
  75. std::map<std::string, std::string> subs_;
  76. };
  77. class NestingSpecs {
  78. public:
  79. NestingSpecs(Stage* stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
  80. ~NestingSpecs() {}
  81. std::string Name(Any node) {
  82. std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
  83. if (stage_->subs().find(name) != stage_->subs().end()) {
  84. return stage_->subs()[name];
  85. }
  86. return name;
  87. }
  88. void Check(std::shared_ptr<FuncGraphAnalysis> results) {
  89. if (expected_.empty() && expected_recursive_.empty()) {
  90. return;
  91. }
  92. auto parent = dynamic_pointer_cast<ParentComputer>(results);
  93. if (parent != nullptr) {
  94. CheckParent(parent);
  95. return;
  96. }
  97. auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
  98. if (recursive != nullptr) {
  99. CheckRecursive(recursive);
  100. return;
  101. }
  102. auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
  103. if (counter_g != nullptr) {
  104. CheckGraphCounter(counter_g);
  105. return;
  106. }
  107. auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector>(results);
  108. if (counter_p != nullptr) {
  109. CheckAnfNodeCounter(counter_p);
  110. return;
  111. }
  112. auto nodes = dynamic_pointer_cast<NodesCollector>(results);
  113. if (nodes != nullptr) {
  114. CheckNodes(nodes);
  115. return;
  116. }
  117. }
  118. private:
  119. void ParseSpecs(std::string specs) {
  120. if (specs.empty()) {
  121. return;
  122. }
  123. std::vector<std::string> str_list = SplitString(specs, ";");
  124. for (auto spec : str_list) {
  125. spec.erase(0, spec.find_first_not_of(" "));
  126. spec.erase(spec.find_last_not_of(" ") + 1);
  127. if (spec.empty()) {
  128. continue;
  129. }
  130. if (spec.find("->") != std::string::npos) {
  131. auto substr = SplitString(spec, "->");
  132. ASSERT_GT(substr.size(), 1);
  133. auto key = substr[0];
  134. auto value = substr[1];
  135. if (!value.empty()) {
  136. expected_[key] = {value};
  137. }
  138. } else if (spec.find(":") != std::string::npos) {
  139. auto substr = SplitString(spec, ":");
  140. ASSERT_GT(substr.size(), 1);
  141. auto key = substr[0];
  142. auto values = SplitString(substr[1], ",");
  143. std::set<std::string> values_set(values.begin(), values.end());
  144. if (!values_set.empty()) {
  145. expected_[key] = values_set;
  146. }
  147. } else {
  148. expected_recursive_[spec] = true;
  149. }
  150. }
  151. }
  152. void CheckParent(std::shared_ptr<ParentComputer> results) {
  153. std::map<std::string, std::set<std::string>> clean_results;
  154. for (auto& iter : results->parent_analysis()) {
  155. auto key = iter.first;
  156. auto value = iter.second;
  157. if (key == nullptr) {
  158. continue;
  159. }
  160. std::string k = Name(key);
  161. std::set<std::string> v;
  162. if (value != nullptr && !Name(value).empty()) {
  163. v.insert(Name(value));
  164. }
  165. if (!v.empty()) {
  166. clean_results[k] = v;
  167. }
  168. }
  169. ASSERT_EQ(clean_results, expected_);
  170. }
  171. void CheckNodes(std::shared_ptr<NodesCollector> results) {
  172. std::map<std::string, std::set<std::string>> clean_results;
  173. for (auto& iter : results->nodes_analysis()) {
  174. auto key = iter.first;
  175. auto value = iter.second;
  176. if (key == nullptr) {
  177. continue;
  178. }
  179. std::string k = Name(key);
  180. std::set<std::string> v;
  181. for (auto& node : value) {
  182. if (!node->isa<CNode>() && !Name(node).empty()) {
  183. v.insert(Name(node));
  184. }
  185. }
  186. if (!v.empty()) {
  187. clean_results[k] = v;
  188. }
  189. }
  190. ASSERT_EQ(clean_results, expected_);
  191. }
  192. // Add CheckNesting function
  193. void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector> results) {
  194. std::map<std::string, std::set<std::string>> clean_results;
  195. for (auto& iter : results->count_nodes_map()) {
  196. auto key = iter.first;
  197. auto value = iter.second;
  198. if (key == nullptr) {
  199. continue;
  200. }
  201. std::string k = Name(key);
  202. std::set<std::string> v;
  203. for (auto& node : value) {
  204. auto fg = node.first;
  205. if (!Name(fg).empty()) {
  206. v.insert(Name(fg));
  207. }
  208. }
  209. if (!v.empty()) {
  210. clean_results[k] = v;
  211. }
  212. }
  213. ASSERT_EQ(clean_results, expected_);
  214. }
  215. void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
  216. std::map<std::string, std::set<std::string>> clean_results;
  217. for (auto& iter : results->count_func_graphs_map()) {
  218. auto key = iter.first;
  219. auto value = iter.second;
  220. if (key == nullptr) {
  221. continue;
  222. }
  223. std::string k = Name(key);
  224. std::set<std::string> v;
  225. for (auto& node : value) {
  226. auto fg = node.first;
  227. if (!Name(fg).empty()) {
  228. v.insert(Name(fg));
  229. }
  230. }
  231. if (!v.empty()) {
  232. clean_results[k] = v;
  233. }
  234. }
  235. ASSERT_EQ(clean_results, expected_);
  236. }
  237. void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
  238. std::map<std::string, bool> clean_results;
  239. for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
  240. auto key = iter->first;
  241. auto value = iter->second;
  242. if (key == nullptr) {
  243. continue;
  244. }
  245. std::string k = Name(key);
  246. clean_results[k] = value;
  247. }
  248. ASSERT_EQ(clean_results, expected_recursive_);
  249. }
  250. private:
  251. Stage* stage_;
  252. std::map<std::string, std::set<std::string>> expected_;
  253. std::map<std::string, bool> expected_recursive_;
  254. };
  255. bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
  256. for (auto node : manager->all_nodes()) {
  257. if (node->isa<CNode>()) {
  258. auto& inputs = node->cast<CNodePtr>()->inputs();
  259. for (size_t i = 0; i < inputs.size(); ++i) {
  260. auto inp = inputs[i];
  261. if (!manager->all_nodes().contains(inp)) {
  262. return false;
  263. }
  264. if (manager->node_users().find(inp) != manager->node_users().end()) {
  265. auto users = manager->node_users()[inp];
  266. if (!users.contains(make_pair(node, i))) {
  267. return false;
  268. }
  269. }
  270. }
  271. }
  272. if (manager->node_users().find(node) != manager->node_users().end()) {
  273. auto users = manager->node_users()[node];
  274. for (auto iter = users.begin(); iter != users.end(); ++iter) {
  275. auto node2 = iter->first;
  276. auto key = iter->second;
  277. if (!manager->all_nodes().contains(node2)) {
  278. return false;
  279. }
  280. if (node2->cast<CNodePtr>()->input(key) != node) {
  281. return false;
  282. }
  283. }
  284. }
  285. }
  286. return true;
  287. }
  288. class TestManager : public UT::Common {
  289. public:
  290. TestManager() : getPyFun("gtest_input.ir.manager_test") {}
  291. void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
  292. public:
  293. std::vector<PrimitivePtr> swaps;
  294. UT::PyFuncGraphFetcher getPyFun;
  295. };
  296. FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  297. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  298. ParameterPtr x = func_graph->add_parameter();
  299. ParameterPtr y = func_graph->add_parameter();
  300. std::vector<AnfNodePtr> inputs;
  301. inputs.push_back(NewValueNode(prim));
  302. inputs.push_back(x);
  303. inputs.push_back(y);
  304. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  305. inputs.clear();
  306. inputs.push_back(NewValueNode(prim::kPrimReturn));
  307. inputs.push_back(cnode_add);
  308. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  309. func_graph->set_return(cnode_return);
  310. return func_graph;
  311. }
  312. std::vector<FuncGraphPtr> MakeNestedGraph() {
  313. /*
  314. *def f(x):
  315. * def g():
  316. * return x
  317. * return g
  318. */
  319. FuncGraphPtr f = std::make_shared<FuncGraph>();
  320. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  321. ParameterPtr x = f->add_parameter();
  322. std::vector<AnfNodePtr> inputs;
  323. inputs.push_back(NewValueNode(fg));
  324. inputs.push_back(NewValueNode(prim::kPrimReturn));
  325. CNodePtr cnode_f = f->NewCNode(inputs);
  326. f->set_return(cnode_f);
  327. inputs.clear();
  328. inputs.push_back(NewValueNode(prim::kPrimReturn));
  329. inputs.push_back(x);
  330. CNodePtr cnode_g = fg->NewCNode(inputs);
  331. fg->set_return(cnode_g);
  332. std::vector<FuncGraphPtr> result = {f, fg};
  333. return result;
  334. }
  335. std::vector<FuncGraphPtr> MakeNestedGraph2() {
  336. /* build a closure func_graph */
  337. /*
  338. *def foo(x, y):
  339. * def bar(x1):
  340. * return x1 + y
  341. * return bar(x)
  342. */
  343. FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
  344. ParameterPtr x = graph_foo->add_parameter();
  345. ParameterPtr y = graph_foo->add_parameter();
  346. std::vector<AnfNodePtr> inputs;
  347. // build func_graph bar
  348. FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
  349. ParameterPtr x1 = graph_bar->add_parameter();
  350. inputs.clear();
  351. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  352. inputs.push_back(x1);
  353. inputs.push_back(y);
  354. CNodePtr cnode_add = graph_bar->NewCNode(inputs);
  355. inputs.clear();
  356. inputs.push_back(NewValueNode(prim::kPrimReturn));
  357. inputs.push_back(cnode_add);
  358. CNodePtr cnode_return = graph_bar->NewCNode(inputs);
  359. graph_bar->set_return(cnode_return);
  360. // build func_graph foo
  361. inputs.clear();
  362. inputs.push_back(NewValueNode(graph_bar));
  363. inputs.push_back(x);
  364. CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
  365. inputs.clear();
  366. inputs.push_back(NewValueNode(prim::kPrimReturn));
  367. inputs.push_back(cnode_graph_bar);
  368. cnode_return = graph_foo->NewCNode(inputs);
  369. graph_foo->set_return(cnode_return);
  370. std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
  371. return result;
  372. }
  373. // Add TestManager::CheckManager function to checkout the result
  374. void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
  375. auto size = mng->func_graphs().size();
  376. ASSERT_EQ(size + 1, mng->nodes().size());
  377. ASSERT_EQ(size, mng->free_variables_total().size());
  378. ASSERT_EQ(size, mng->valuenodes().size());
  379. ASSERT_EQ(size, mng->free_variables_direct().size());
  380. ASSERT_EQ(size, mng->func_graph_valuenodes().size());
  381. ASSERT_EQ(size, mng->func_graph_parents_direct().size());
  382. ASSERT_EQ(size, mng->func_graph_users().size());
  383. ASSERT_EQ(size, mng->func_graphs_used().size());
  384. }
  385. TEST_F(TestManager, test_scalar_add_manual) {
  386. auto prim_scalar_add = prim::kPrimScalarAdd;
  387. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  388. auto mng = Manage(func_graph);
  389. }
  390. TEST_F(TestManager, test_scalar_replace) {
  391. auto prim_scalar_add = prim::kPrimScalarAdd;
  392. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  393. ParameterPtr x = func_graph->add_parameter();
  394. ParameterPtr y = func_graph->add_parameter();
  395. std::vector<AnfNodePtr> inputs;
  396. inputs.push_back(NewValueNode(prim_scalar_add));
  397. inputs.push_back(x);
  398. inputs.push_back(y);
  399. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  400. inputs.clear();
  401. inputs.push_back(NewValueNode(prim::kPrimReturn));
  402. inputs.push_back(cnode_add);
  403. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  404. func_graph->set_return(cnode_return);
  405. auto mng = Manage(func_graph);
  406. std::cout << "start " << x->ToString() << std::endl;
  407. mng->Replace(cnode_add, x);
  408. }
  409. TEST_F(TestManager, test_nested_manual) {
  410. auto graphs = MakeNestedGraph();
  411. auto f = graphs[0];
  412. auto g = graphs[1];
  413. auto mng = Manage(f);
  414. ASSERT_EQ(6, mng->all_nodes().size());
  415. ASSERT_EQ(2, mng->func_graphs().size());
  416. ASSERT_EQ(4, mng->node_users().size());
  417. ASSERT_EQ(1, mng->roots().size());
  418. CheckAnalysisSize(mng);
  419. auto nodes = mng->nodes();
  420. ASSERT_EQ(3, nodes[nullptr].size());
  421. ASSERT_EQ(2, nodes[f].size());
  422. ASSERT_EQ(1, nodes[g].size());
  423. auto users = mng->node_users();
  424. for (auto& iter : users) {
  425. ASSERT_EQ(1, iter.second.size());
  426. }
  427. auto graphs_used = mng->func_graphs_used();
  428. ASSERT_EQ(1, graphs_used[f].size());
  429. ASSERT_EQ(0, graphs_used[g].size());
  430. auto graph_users = mng->func_graph_users();
  431. ASSERT_EQ(0, graph_users[f].size());
  432. ASSERT_EQ(1, graph_users[g].size());
  433. auto fv_direct = mng->free_variables_direct();
  434. ASSERT_EQ(0, fv_direct[f].size());
  435. ASSERT_EQ(1, fv_direct[g].size());
  436. auto fv_total = mng->free_variables_total();
  437. ASSERT_EQ(0, fv_total[f].size());
  438. ASSERT_EQ(1, fv_total[g].size());
  439. auto graph_valuenodes = mng->func_graph_valuenodes();
  440. ASSERT_EQ(0, graph_valuenodes[f].size());
  441. ASSERT_EQ(1, graph_valuenodes[g].size());
  442. }
  443. TEST_F(TestManager, test_deep_nested2_manual) {
  444. // create parser
  445. FuncGraphPtr func_graph = getPyFun("test_custom");
  446. return;
  447. // parse ast to func graph
  448. FuncGraphPtr gfn = BasicClone(func_graph);
  449. if (gfn == nullptr) {
  450. return;
  451. }
  452. auto mng = Manage(gfn);
  453. ASSERT_EQ(3, mng->func_graphs().size());
  454. ASSERT_EQ(1, mng->roots().size());
  455. ASSERT_EQ(4, mng->nodes().size());
  456. ASSERT_EQ(20, mng->all_nodes().size());
  457. ASSERT_EQ(25, mng->node_users().size());
  458. CheckAnalysisSize(mng);
  459. }
  460. TEST_F(TestManager, test_deep_nested_manual) {
  461. FuncGraphPtr f = std::make_shared<FuncGraph>();
  462. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  463. FuncGraphPtr h = std::make_shared<FuncGraph>();
  464. ParameterPtr x = f->add_parameter();
  465. ParameterPtr y = f->add_parameter();
  466. ParameterPtr z = f->add_parameter();
  467. std::vector<AnfNodePtr> inputs;
  468. inputs.push_back(NewValueNode(fg));
  469. inputs.push_back(x);
  470. inputs.push_back(y);
  471. CNodePtr cnode_1 = f->NewCNode(inputs);
  472. inputs.clear();
  473. inputs.push_back(cnode_1);
  474. inputs.push_back(NewValueNode(prim::kPrimReturn));
  475. CNodePtr cnode_0 = f->NewCNode(inputs);
  476. f->set_return(cnode_0);
  477. ParameterPtr x1 = fg->add_parameter();
  478. ParameterPtr y1 = fg->add_parameter();
  479. inputs.clear();
  480. inputs.push_back(NewValueNode(h));
  481. inputs.push_back(x1);
  482. CNodePtr cnode_3 = fg->NewCNode(inputs);
  483. inputs.clear();
  484. inputs.push_back(cnode_3);
  485. inputs.push_back(NewValueNode(prim::kPrimReturn));
  486. CNodePtr cnode_2 = fg->NewCNode(inputs);
  487. fg->set_return(cnode_2);
  488. ParameterPtr x2 = h->add_parameter();
  489. inputs.clear();
  490. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  491. inputs.push_back(x2);
  492. inputs.push_back(y1);
  493. CNodePtr cnode_6 = h->NewCNode(inputs);
  494. inputs.clear();
  495. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  496. inputs.push_back(z);
  497. inputs.push_back(cnode_6);
  498. CNodePtr cnode_5 = h->NewCNode(inputs);
  499. inputs.clear();
  500. inputs.push_back(cnode_5);
  501. inputs.push_back(NewValueNode(prim::kPrimReturn));
  502. CNodePtr cnode_4 = h->NewCNode(inputs);
  503. h->set_return(cnode_4);
  504. auto mng = Manage(f);
  505. ASSERT_EQ(3, mng->func_graphs().size());
  506. ASSERT_EQ(1, mng->roots().size());
  507. ASSERT_EQ(4, mng->nodes().size());
  508. ASSERT_EQ(20, mng->all_nodes().size());
  509. CheckAnalysisSize(mng);
  510. }
  511. TEST_F(TestManager, test_parent1_manual) {
  512. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  513. Parameter param(fg);
  514. std::vector<AnfNodePtr> params;
  515. CNodePtr app = std::make_shared<CNode>(params, fg);
  516. fg->set_return(app);
  517. fg->set_parameters(params);
  518. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  519. manager->AddFuncGraph(fg, true);
  520. FuncGraphPtr p = fg->parent();
  521. assert(p == nullptr);
  522. }
  523. TEST_F(TestManager, test_parent_manual) {
  524. auto prim_scalar_add = prim::kPrimScalarAdd;
  525. FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
  526. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  527. manager->AddFuncGraph(fg);
  528. FuncGraphPtr p = fg->parent();
  529. assert(p == nullptr);
  530. }
  531. TEST_F(TestManager, test_flat) {
  532. std::vector<std::shared_ptr<Stage>> stages;
  533. std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  534. std::map<std::string, int> size_list;
  535. size_list["nodes"] = 2;
  536. }
  537. TEST_F(TestManager, test_nested) {
  538. std::vector<std::shared_ptr<Stage>> stages;
  539. std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  540. std::map<std::string, int> size_list;
  541. return;
  542. }
  543. TEST_F(TestManager, test_calls) {
  544. std::vector<std::shared_ptr<Stage>> stages;
  545. std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
  546. "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  547. std::map<std::string, int> size_list;
  548. return;
  549. }
  550. TEST_F(TestManager, test_unused_param) {
  551. std::vector<std::shared_ptr<Stage>> stages;
  552. std::vector<std::string> specs = {"nodes=X:x,y"};
  553. std::map<std::string, int> size_list;
  554. }
  555. TEST_F(TestManager, test_cannot_replace_return) {
  556. FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  557. ASSERT_NE(fg, nullptr);
  558. auto mng = Manage(fg);
  559. ASSERT_EQ(fg->manager(), mng);
  560. ASSERT_NE(mng, nullptr);
  561. ASSERT_GT(fg->parameters().size(), 0);
  562. ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
  563. }
  564. TEST_F(TestManager, test_weak_manager) {
  565. FuncGraphPtr fg = getPyFun("ir_get_fn");
  566. auto mng1 = MakeManager({fg}, false);
  567. ASSERT_EQ(fg->manager(), nullptr);
  568. auto mng2 = MakeManager({fg}, true);
  569. ASSERT_EQ(fg->manager(), mng2);
  570. auto mng3 = MakeManager({fg}, false);
  571. ASSERT_EQ(fg->manager(), mng2);
  572. }
  573. TEST_F(TestManager, test_drop_root) {
  574. FuncGraphPtr fg = getPyFun("ir_get_fn");
  575. auto mng = Manage(fg);
  576. const FuncGraphToAnfNodeMap& nodes = mng->nodes();
  577. ASSERT_TRUE(nodes.find(fg) != nodes.end());
  578. FuncGraphSet s;
  579. s.add(fg);
  580. mng->MaybeDropFuncGraphs(s);
  581. ASSERT_TRUE(nodes.find(fg) != nodes.end());
  582. }
  583. TEST_F(TestManager, test_keep_roots) {
  584. FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  585. FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  586. auto mng = Manage(fg1);
  587. ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
  588. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  589. mng->AddFuncGraph(fg2);
  590. ASSERT_EQ(mng->func_graphs().size(), 2);
  591. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  592. mng->KeepRoots();
  593. ASSERT_EQ(mng->func_graphs().size(), 1);
  594. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  595. mng->KeepRoots({fg2});
  596. ASSERT_EQ(mng->func_graphs().size(), 1);
  597. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  598. }
  599. TEST_F(TestManager, test_keep_roots_recursion) {
  600. return;
  601. FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
  602. ASSERT_NE(fg, nullptr);
  603. auto mng = Manage(fg);
  604. parse::ResolveAll(mng);
  605. ASSERT_NE(mng, nullptr);
  606. ASSERT_EQ(mng->func_graphs().size(), 4);
  607. ASSERT_GT(fg->parameters().size(), 0);
  608. mng->Replace(fg->output(), fg->parameters()[0]);
  609. ASSERT_EQ(mng->func_graphs().size(), 3);
  610. mng->KeepRoots();
  611. ASSERT_EQ(mng->func_graphs().size(), 1);
  612. }
  613. } // namespace mindspore