You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager_test.cc 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "common/py_func_graph_fetcher.h"
  18. #include "ir/dtype.h"
  19. #include "ir/manager.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "pipeline/parse/parse.h"
  22. #include "operator/ops.h"
  23. #include "utils/log_adapter.h"
  24. #include "debug/draw.h"
  25. #include "debug/label.h"
  26. #include "./common.h"
  27. namespace mindspore {
  28. namespace {
  // Splits `str` on every occurrence of `pattern`, scanning left to right and
  // resuming right after each match. Empty fields are kept, including a
  // trailing one: "a;b;" -> {"a", "b", ""}, and "" -> {""}.
  // An empty `pattern` cannot delimit anything, so the result is {str}.
  std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
    std::vector<std::string> result;
    if (pattern.empty()) {
      // find("") matches at every position; without this guard the scan never advances.
      result.push_back(str);
      return result;
    }
    std::string::size_type start = 0;
    for (;;) {
      const std::string::size_type pos = str.find(pattern, start);
      if (pos == std::string::npos) {
        // Whatever follows the last delimiter (possibly "") is the final field.
        result.push_back(str.substr(start));
        return result;
      }
      result.push_back(str.substr(start, pos - start));
      start = pos + pattern.size();
    }
  }
  44. } // namespace
  45. using std::dynamic_pointer_cast;
  46. using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
  47. using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
  48. class NestingSpecs;
  49. class Stage {
  50. public:
  51. explicit Stage(std::vector<std::string> specs) {
  52. for (auto arg : specs) {
  53. auto spec = SplitString(arg, "=");
  54. if (spec.size() <= 1) {
  55. continue;
  56. }
  57. std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
  58. specs_[ToFullString(spec[0])] = nesting;
  59. }
  60. }
  61. ~Stage() {}
  62. std::map<std::string, std::string>& subs() { return subs_; }
  63. void set_subs(const std::map<std::string, std::string>& subs) { subs_ = subs; }
  64. private:
  65. std::string ToFullString(std::string s) {
  66. if (s.find("fv") != std::string::npos) {
  67. s = s.replace(s.find("fv"), 2, "free_variable");
  68. }
  69. if (s.find("deps") != std::string::npos) {
  70. s = s.replace(s.find("deps"), 4, "dependencies");
  71. }
  72. return s;
  73. }
  74. std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
  75. std::map<std::string, std::string> subs_;
  76. };
  77. class NestingSpecs {
  78. public:
  79. NestingSpecs(Stage* stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
  80. ~NestingSpecs() {}
  81. std::string Name(Any node) {
  82. std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
  83. if (stage_->subs().find(name) != stage_->subs().end()) {
  84. return stage_->subs()[name];
  85. }
  86. return name;
  87. }
  88. void Check(std::shared_ptr<FuncGraphAnalysis> results) {
  89. if (expected_.empty() && expected_recursive_.empty()) {
  90. return;
  91. }
  92. auto parent = dynamic_pointer_cast<ParentComputer>(results);
  93. if (parent != nullptr) {
  94. CheckParent(parent);
  95. return;
  96. }
  97. auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
  98. if (recursive != nullptr) {
  99. CheckRecursive(recursive);
  100. return;
  101. }
  102. auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
  103. if (counter_g != nullptr) {
  104. CheckGraphCounter(counter_g);
  105. return;
  106. }
  107. auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
  108. if (counter_p != nullptr) {
  109. CheckAnfNodeCounter(counter_p);
  110. return;
  111. }
  112. auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
  113. if (counter_pair != nullptr) {
  114. CheckCNodeIndexPairCounter(counter_pair);
  115. return;
  116. }
  117. auto nodes = dynamic_pointer_cast<NodesCollector>(results);
  118. if (nodes != nullptr) {
  119. CheckNodes(nodes);
  120. return;
  121. }
  122. }
  123. private:
  124. void ParseSpecs(std::string specs) {
  125. if (specs.empty()) {
  126. return;
  127. }
  128. std::vector<std::string> str_list = SplitString(specs, ";");
  129. for (auto spec : str_list) {
  130. spec.erase(0, spec.find_first_not_of(" "));
  131. spec.erase(spec.find_last_not_of(" ") + 1);
  132. if (spec.empty()) {
  133. continue;
  134. }
  135. if (spec.find("->") != std::string::npos) {
  136. auto substr = SplitString(spec, "->");
  137. ASSERT_GT(substr.size(), 1);
  138. auto key = substr[0];
  139. auto value = substr[1];
  140. if (!value.empty()) {
  141. expected_[key] = {value};
  142. }
  143. } else if (spec.find(":") != std::string::npos) {
  144. auto substr = SplitString(spec, ":");
  145. ASSERT_GT(substr.size(), 1);
  146. auto key = substr[0];
  147. auto values = SplitString(substr[1], ",");
  148. std::set<std::string> values_set(values.begin(), values.end());
  149. if (!values_set.empty()) {
  150. expected_[key] = values_set;
  151. }
  152. } else {
  153. expected_recursive_[spec] = true;
  154. }
  155. }
  156. }
  157. void CheckParent(std::shared_ptr<ParentComputer> results) {
  158. std::map<std::string, std::set<std::string>> clean_results;
  159. for (auto& iter : results->parent_analysis()) {
  160. auto key = iter.first;
  161. auto value = iter.second;
  162. if (key == nullptr) {
  163. continue;
  164. }
  165. std::string k = Name(key);
  166. std::set<std::string> v;
  167. if (value != nullptr && !Name(value).empty()) {
  168. v.insert(Name(value));
  169. }
  170. if (!v.empty()) {
  171. clean_results[k] = v;
  172. }
  173. }
  174. ASSERT_EQ(clean_results, expected_);
  175. }
  176. void CheckNodes(std::shared_ptr<NodesCollector> results) {
  177. std::map<std::string, std::set<std::string>> clean_results;
  178. for (auto& iter : results->nodes_analysis()) {
  179. auto key = iter.first;
  180. auto value = iter.second;
  181. if (key == nullptr) {
  182. continue;
  183. }
  184. std::string k = Name(key);
  185. std::set<std::string> v;
  186. for (auto& node : value) {
  187. if (!node->isa<CNode>() && !Name(node).empty()) {
  188. v.insert(Name(node));
  189. }
  190. }
  191. if (!v.empty()) {
  192. clean_results[k] = v;
  193. }
  194. }
  195. ASSERT_EQ(clean_results, expected_);
  196. }
  197. // Add CheckNesting function
  198. void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
  199. std::map<std::string, std::set<std::string>> clean_results;
  200. for (auto& iter : results->count_nodes_map()) {
  201. auto key = iter.first;
  202. auto value = iter.second;
  203. if (key == nullptr) {
  204. continue;
  205. }
  206. std::string k = Name(key);
  207. std::set<std::string> v;
  208. for (auto& node : value) {
  209. auto fg = node.first;
  210. if (!Name(fg).empty()) {
  211. v.insert(Name(fg));
  212. }
  213. }
  214. if (!v.empty()) {
  215. clean_results[k] = v;
  216. }
  217. }
  218. ASSERT_EQ(clean_results, expected_);
  219. }
  220. void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
  221. std::map<std::string, std::set<std::string>> clean_results;
  222. for (auto& iter : results->count_nodes_map()) {
  223. auto key = iter.first;
  224. auto value = iter.second;
  225. if (key == nullptr) {
  226. continue;
  227. }
  228. std::string k = Name(key);
  229. std::set<std::string> v;
  230. for (auto& node : value) {
  231. auto fg = node.first->first;
  232. if (!Name(fg).empty()) {
  233. v.insert(Name(fg));
  234. }
  235. }
  236. if (!v.empty()) {
  237. clean_results[k] = v;
  238. }
  239. }
  240. ASSERT_EQ(clean_results, expected_);
  241. }
  242. void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
  243. std::map<std::string, std::set<std::string>> clean_results;
  244. for (auto& iter : results->count_func_graphs_map()) {
  245. auto key = iter.first;
  246. auto value = iter.second;
  247. if (key == nullptr) {
  248. continue;
  249. }
  250. std::string k = Name(key);
  251. std::set<std::string> v;
  252. for (auto& node : value) {
  253. auto fg = node.first;
  254. if (!Name(fg).empty()) {
  255. v.insert(Name(fg));
  256. }
  257. }
  258. if (!v.empty()) {
  259. clean_results[k] = v;
  260. }
  261. }
  262. ASSERT_EQ(clean_results, expected_);
  263. }
  264. void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
  265. std::map<std::string, bool> clean_results;
  266. for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
  267. auto key = iter->first;
  268. auto value = iter->second;
  269. if (key == nullptr) {
  270. continue;
  271. }
  272. std::string k = Name(key);
  273. clean_results[k] = value;
  274. }
  275. ASSERT_EQ(clean_results, expected_recursive_);
  276. }
  277. private:
  278. Stage* stage_;
  279. std::map<std::string, std::set<std::string>> expected_;
  280. std::map<std::string, bool> expected_recursive_;
  281. };
  282. bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
  283. for (auto node : manager->all_nodes()) {
  284. if (node->isa<CNode>()) {
  285. auto& inputs = node->cast<CNodePtr>()->inputs();
  286. for (size_t i = 0; i < inputs.size(); ++i) {
  287. auto inp = inputs[i];
  288. if (!manager->all_nodes().contains(inp)) {
  289. return false;
  290. }
  291. if (manager->node_users().find(inp) != manager->node_users().end()) {
  292. auto users = manager->node_users()[inp];
  293. if (!users.contains(make_pair(node, i))) {
  294. return false;
  295. }
  296. }
  297. }
  298. }
  299. if (manager->node_users().find(node) != manager->node_users().end()) {
  300. auto users = manager->node_users()[node];
  301. for (auto iter = users.begin(); iter != users.end(); ++iter) {
  302. auto node2 = iter->first;
  303. auto key = iter->second;
  304. if (!manager->all_nodes().contains(node2)) {
  305. return false;
  306. }
  307. if (node2->cast<CNodePtr>()->input(key) != node) {
  308. return false;
  309. }
  310. }
  311. }
  312. }
  313. return true;
  314. }
  315. class TestManager : public UT::Common {
  316. public:
  317. TestManager() : getPyFun("gtest_input.ir.manager_test") {}
  318. void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
  319. public:
  320. std::vector<PrimitivePtr> swaps;
  321. UT::PyFuncGraphFetcher getPyFun;
  322. };
  323. FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  324. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  325. ParameterPtr x = func_graph->add_parameter();
  326. ParameterPtr y = func_graph->add_parameter();
  327. std::vector<AnfNodePtr> inputs;
  328. inputs.push_back(NewValueNode(prim));
  329. inputs.push_back(x);
  330. inputs.push_back(y);
  331. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  332. inputs.clear();
  333. inputs.push_back(NewValueNode(prim::kPrimReturn));
  334. inputs.push_back(cnode_add);
  335. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  336. func_graph->set_return(cnode_return);
  337. return func_graph;
  338. }
  339. std::vector<FuncGraphPtr> MakeNestedGraph() {
  340. /*
  341. *def f(x):
  342. * def g():
  343. * return x
  344. * return g
  345. */
  346. FuncGraphPtr f = std::make_shared<FuncGraph>();
  347. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  348. ParameterPtr x = f->add_parameter();
  349. std::vector<AnfNodePtr> inputs;
  350. inputs.push_back(NewValueNode(fg));
  351. inputs.push_back(NewValueNode(prim::kPrimReturn));
  352. CNodePtr cnode_f = f->NewCNode(inputs);
  353. f->set_return(cnode_f);
  354. inputs.clear();
  355. inputs.push_back(NewValueNode(prim::kPrimReturn));
  356. inputs.push_back(x);
  357. CNodePtr cnode_g = fg->NewCNode(inputs);
  358. fg->set_return(cnode_g);
  359. std::vector<FuncGraphPtr> result = {f, fg};
  360. return result;
  361. }
  362. std::vector<FuncGraphPtr> MakeNestedGraph2() {
  363. /* build a closure func_graph */
  364. /*
  365. *def foo(x, y):
  366. * def bar(x1):
  367. * return x1 + y
  368. * return bar(x)
  369. */
  370. FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
  371. ParameterPtr x = graph_foo->add_parameter();
  372. ParameterPtr y = graph_foo->add_parameter();
  373. std::vector<AnfNodePtr> inputs;
  374. // build func_graph bar
  375. FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
  376. ParameterPtr x1 = graph_bar->add_parameter();
  377. inputs.clear();
  378. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  379. inputs.push_back(x1);
  380. inputs.push_back(y);
  381. CNodePtr cnode_add = graph_bar->NewCNode(inputs);
  382. inputs.clear();
  383. inputs.push_back(NewValueNode(prim::kPrimReturn));
  384. inputs.push_back(cnode_add);
  385. CNodePtr cnode_return = graph_bar->NewCNode(inputs);
  386. graph_bar->set_return(cnode_return);
  387. // build func_graph foo
  388. inputs.clear();
  389. inputs.push_back(NewValueNode(graph_bar));
  390. inputs.push_back(x);
  391. CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
  392. inputs.clear();
  393. inputs.push_back(NewValueNode(prim::kPrimReturn));
  394. inputs.push_back(cnode_graph_bar);
  395. cnode_return = graph_foo->NewCNode(inputs);
  396. graph_foo->set_return(cnode_return);
  397. std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
  398. return result;
  399. }
  400. // Add TestManager::CheckManager function to checkout the result
  401. void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
  402. auto size = mng->func_graphs().size();
  403. ASSERT_EQ(size + 1, mng->nodes().size());
  404. ASSERT_EQ(size, mng->free_variables_total().size());
  405. ASSERT_EQ(size, mng->valuenodes().size());
  406. ASSERT_EQ(size, mng->free_variables_direct().size());
  407. ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
  408. ASSERT_EQ(size, mng->func_graph_parents_direct().size());
  409. ASSERT_EQ(size, mng->func_graphs_used().size());
  410. }
  411. TEST_F(TestManager, test_scalar_add_manual) {
  412. auto prim_scalar_add = prim::kPrimScalarAdd;
  413. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  414. auto mng = Manage(func_graph);
  415. }
  416. TEST_F(TestManager, test_scalar_replace) {
  417. auto prim_scalar_add = prim::kPrimScalarAdd;
  418. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  419. ParameterPtr x = func_graph->add_parameter();
  420. ParameterPtr y = func_graph->add_parameter();
  421. std::vector<AnfNodePtr> inputs;
  422. inputs.push_back(NewValueNode(prim_scalar_add));
  423. inputs.push_back(x);
  424. inputs.push_back(y);
  425. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  426. inputs.clear();
  427. inputs.push_back(NewValueNode(prim::kPrimReturn));
  428. inputs.push_back(cnode_add);
  429. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  430. func_graph->set_return(cnode_return);
  431. auto mng = Manage(func_graph);
  432. std::cout << "start " << x->ToString() << std::endl;
  433. mng->Replace(cnode_add, x);
  434. }
  435. TEST_F(TestManager, test_nested_manual) {
  436. auto graphs = MakeNestedGraph();
  437. auto f = graphs[0];
  438. auto g = graphs[1];
  439. auto mng = Manage(f);
  440. ASSERT_EQ(6, mng->all_nodes().size());
  441. ASSERT_EQ(2, mng->func_graphs().size());
  442. ASSERT_EQ(4, mng->node_users().size());
  443. ASSERT_EQ(1, mng->roots().size());
  444. CheckAnalysisSize(mng);
  445. auto nodes = mng->nodes();
  446. ASSERT_EQ(3, nodes[nullptr].size());
  447. ASSERT_EQ(2, nodes[f].size());
  448. ASSERT_EQ(1, nodes[g].size());
  449. auto users = mng->node_users();
  450. for (auto& iter : users) {
  451. ASSERT_EQ(1, iter.second.size());
  452. }
  453. auto graphs_used = mng->func_graphs_used();
  454. ASSERT_EQ(1, graphs_used[f].size());
  455. ASSERT_EQ(0, graphs_used[g].size());
  456. auto fv_direct = mng->free_variables_direct();
  457. ASSERT_EQ(0, fv_direct[f].size());
  458. ASSERT_EQ(1, fv_direct[g].size());
  459. auto fv_total = mng->free_variables_total();
  460. ASSERT_EQ(0, fv_total[f].size());
  461. ASSERT_EQ(1, fv_total[g].size());
  462. auto cnodes = mng->func_graph_cnodes_index();
  463. ASSERT_EQ(0, cnodes[f].size());
  464. ASSERT_EQ(1, cnodes[g].size());
  465. }
  466. TEST_F(TestManager, test_deep_nested2_manual) {
  467. // create parser
  468. FuncGraphPtr func_graph = getPyFun("test_custom");
  469. return;
  470. // parse ast to func graph
  471. FuncGraphPtr gfn = BasicClone(func_graph);
  472. if (gfn == nullptr) {
  473. return;
  474. }
  475. auto mng = Manage(gfn);
  476. ASSERT_EQ(3, mng->func_graphs().size());
  477. ASSERT_EQ(1, mng->roots().size());
  478. ASSERT_EQ(4, mng->nodes().size());
  479. ASSERT_EQ(20, mng->all_nodes().size());
  480. ASSERT_EQ(25, mng->node_users().size());
  481. CheckAnalysisSize(mng);
  482. }
  483. TEST_F(TestManager, test_deep_nested_manual) {
  484. FuncGraphPtr f = std::make_shared<FuncGraph>();
  485. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  486. FuncGraphPtr h = std::make_shared<FuncGraph>();
  487. ParameterPtr x = f->add_parameter();
  488. ParameterPtr y = f->add_parameter();
  489. ParameterPtr z = f->add_parameter();
  490. std::vector<AnfNodePtr> inputs;
  491. inputs.push_back(NewValueNode(fg));
  492. inputs.push_back(x);
  493. inputs.push_back(y);
  494. CNodePtr cnode_1 = f->NewCNode(inputs);
  495. inputs.clear();
  496. inputs.push_back(cnode_1);
  497. inputs.push_back(NewValueNode(prim::kPrimReturn));
  498. CNodePtr cnode_0 = f->NewCNode(inputs);
  499. f->set_return(cnode_0);
  500. ParameterPtr x1 = fg->add_parameter();
  501. ParameterPtr y1 = fg->add_parameter();
  502. inputs.clear();
  503. inputs.push_back(NewValueNode(h));
  504. inputs.push_back(x1);
  505. CNodePtr cnode_3 = fg->NewCNode(inputs);
  506. inputs.clear();
  507. inputs.push_back(cnode_3);
  508. inputs.push_back(NewValueNode(prim::kPrimReturn));
  509. CNodePtr cnode_2 = fg->NewCNode(inputs);
  510. fg->set_return(cnode_2);
  511. ParameterPtr x2 = h->add_parameter();
  512. inputs.clear();
  513. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  514. inputs.push_back(x2);
  515. inputs.push_back(y1);
  516. CNodePtr cnode_6 = h->NewCNode(inputs);
  517. inputs.clear();
  518. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  519. inputs.push_back(z);
  520. inputs.push_back(cnode_6);
  521. CNodePtr cnode_5 = h->NewCNode(inputs);
  522. inputs.clear();
  523. inputs.push_back(cnode_5);
  524. inputs.push_back(NewValueNode(prim::kPrimReturn));
  525. CNodePtr cnode_4 = h->NewCNode(inputs);
  526. h->set_return(cnode_4);
  527. auto mng = Manage(f);
  528. ASSERT_EQ(3, mng->func_graphs().size());
  529. ASSERT_EQ(1, mng->roots().size());
  530. ASSERT_EQ(4, mng->nodes().size());
  531. ASSERT_EQ(20, mng->all_nodes().size());
  532. CheckAnalysisSize(mng);
  533. }
  534. TEST_F(TestManager, test_parent1_manual) {
  535. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  536. Parameter param(fg);
  537. std::vector<AnfNodePtr> params;
  538. CNodePtr app = std::make_shared<CNode>(params, fg);
  539. fg->set_return(app);
  540. fg->set_parameters(params);
  541. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  542. manager->AddFuncGraph(fg, true);
  543. FuncGraphPtr p = fg->parent();
  544. assert(p == nullptr);
  545. }
  546. TEST_F(TestManager, test_parent_manual) {
  547. auto prim_scalar_add = prim::kPrimScalarAdd;
  548. FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
  549. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  550. manager->AddFuncGraph(fg);
  551. FuncGraphPtr p = fg->parent();
  552. assert(p == nullptr);
  553. }
  554. TEST_F(TestManager, test_flat) {
  555. std::vector<std::shared_ptr<Stage>> stages;
  556. std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  557. std::map<std::string, int> size_list;
  558. size_list["nodes"] = 2;
  559. }
  560. TEST_F(TestManager, test_nested) {
  561. std::vector<std::shared_ptr<Stage>> stages;
  562. std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  563. std::map<std::string, int> size_list;
  564. return;
  565. }
  566. TEST_F(TestManager, test_calls) {
  567. std::vector<std::shared_ptr<Stage>> stages;
  568. std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
  569. "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  570. std::map<std::string, int> size_list;
  571. return;
  572. }
  573. TEST_F(TestManager, test_unused_param) {
  574. std::vector<std::shared_ptr<Stage>> stages;
  575. std::vector<std::string> specs = {"nodes=X:x,y"};
  576. std::map<std::string, int> size_list;
  577. }
  578. TEST_F(TestManager, test_cannot_replace_return) {
  579. FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  580. ASSERT_NE(fg, nullptr);
  581. auto mng = Manage(fg);
  582. ASSERT_EQ(fg->manager(), mng);
  583. ASSERT_NE(mng, nullptr);
  584. ASSERT_GT(fg->parameters().size(), 0);
  585. ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
  586. }
  587. TEST_F(TestManager, test_weak_manager) {
  588. FuncGraphPtr fg = getPyFun("ir_get_fn");
  589. auto mng1 = MakeManager({fg}, false);
  590. ASSERT_EQ(fg->manager(), nullptr);
  591. auto mng2 = MakeManager({fg}, true);
  592. ASSERT_EQ(fg->manager(), mng2);
  593. auto mng3 = MakeManager({fg}, false);
  594. ASSERT_EQ(fg->manager(), mng2);
  595. }
  596. TEST_F(TestManager, test_drop_root) {
  597. FuncGraphPtr fg = getPyFun("ir_get_fn");
  598. auto mng = Manage(fg);
  599. const FuncGraphToAnfNodeMap& nodes = mng->nodes();
  600. ASSERT_TRUE(nodes.find(fg) != nodes.end());
  601. FuncGraphSet s;
  602. s.add(fg);
  603. mng->MaybeDropFuncGraphs(s);
  604. ASSERT_TRUE(nodes.find(fg) != nodes.end());
  605. }
  606. TEST_F(TestManager, test_keep_roots) {
  607. FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  608. FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  609. auto mng = Manage(fg1);
  610. ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
  611. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  612. mng->AddFuncGraph(fg2);
  613. ASSERT_EQ(mng->func_graphs().size(), 2);
  614. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  615. mng->KeepRoots();
  616. ASSERT_EQ(mng->func_graphs().size(), 1);
  617. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  618. mng->KeepRoots({fg2});
  619. ASSERT_EQ(mng->func_graphs().size(), 1);
  620. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  621. }
  622. TEST_F(TestManager, test_keep_roots_recursion) {
  623. return;
  624. FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
  625. ASSERT_NE(fg, nullptr);
  626. auto mng = Manage(fg);
  627. parse::ResolveAll(mng);
  628. ASSERT_NE(mng, nullptr);
  629. ASSERT_EQ(mng->func_graphs().size(), 4);
  630. ASSERT_GT(fg->parameters().size(), 0);
  631. mng->Replace(fg->output(), fg->parameters()[0]);
  632. ASSERT_EQ(mng->func_graphs().size(), 3);
  633. mng->KeepRoots();
  634. ASSERT_EQ(mng->func_graphs().size(), 1);
  635. }
  636. } // namespace mindspore