You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager_test.cc 18 kB


  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "common/py_func_graph_fetcher.h"
  18. #include "ir/dtype.h"
  19. #include "ir/manager.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "frontend/operator/ops.h"
  23. #include "utils/log_adapter.h"
  24. #include "debug/draw.h"
  25. #include "debug/label.h"
  26. #include "./common.h"
  27. namespace mindspore {
  namespace {
  // Splits `str` on every occurrence of `pattern` and returns the pieces in order.
  //
  // Empty pieces are preserved: "a,,b," split on "," gives {"a", "", "b", ""}, and an
  // empty `str` gives {""}. An empty `pattern` cannot delimit anything, so the whole
  // input is returned as a single piece (the previous implementation looped forever).
  std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
    if (pattern.empty()) {
      return {str};
    }
    std::vector<std::string> result;
    std::string::size_type start = 0;
    while (true) {
      const std::string::size_type pos = str.find(pattern, start);
      if (pos == std::string::npos) {
        // The text after the last delimiter (possibly empty) is the final piece.
        result.push_back(str.substr(start));
        break;
      }
      result.push_back(str.substr(start, pos - start));
      start = pos + pattern.size();
    }
    return result;
  }
  }  // namespace
  45. using std::dynamic_pointer_cast;
  46. using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
  47. using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
  48. class NestingSpecs;
  49. class Stage {
  50. public:
  51. explicit Stage(std::vector<std::string> specs) {
  52. for (auto arg : specs) {
  53. auto spec = SplitString(arg, "=");
  54. if (spec.size() <= 1) {
  55. continue;
  56. }
  57. std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
  58. specs_[ToFullString(spec[0])] = nesting;
  59. }
  60. }
  61. ~Stage() {}
  62. std::map<std::string, std::string>& subs() { return subs_; }
  63. void set_subs(const std::map<std::string, std::string>& subs) { subs_ = subs; }
  64. private:
  65. std::string ToFullString(std::string s) {
  66. if (s.find("fv") != std::string::npos) {
  67. s = s.replace(s.find("fv"), 2, "free_variable");
  68. }
  69. if (s.find("deps") != std::string::npos) {
  70. s = s.replace(s.find("deps"), 4, "dependencies");
  71. }
  72. return s;
  73. }
  74. std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
  75. std::map<std::string, std::string> subs_;
  76. };
  77. class NestingSpecs {
  78. public:
  79. NestingSpecs(Stage* stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
  80. ~NestingSpecs() {}
  81. std::string Name(Any node) {
  82. std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
  83. if (stage_->subs().find(name) != stage_->subs().end()) {
  84. return stage_->subs()[name];
  85. }
  86. return name;
  87. }
  88. void Check(std::shared_ptr<DepComputer> results) {
  89. if (expected_.empty() && expected_recursive_.empty()) {
  90. return;
  91. }
  92. auto parent = dynamic_pointer_cast<ParentComputer>(results);
  93. if (parent != nullptr) {
  94. CheckParent(parent);
  95. return;
  96. }
  97. auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
  98. if (recursive != nullptr) {
  99. CheckRecursive(recursive);
  100. return;
  101. }
  102. }
  103. private:
  104. void ParseSpecs(std::string specs) {
  105. if (specs.empty()) {
  106. return;
  107. }
  108. std::vector<std::string> str_list = SplitString(specs, ";");
  109. for (auto spec : str_list) {
  110. spec.erase(0, spec.find_first_not_of(" "));
  111. spec.erase(spec.find_last_not_of(" ") + 1);
  112. if (spec.empty()) {
  113. continue;
  114. }
  115. if (spec.find("->") != std::string::npos) {
  116. auto substr = SplitString(spec, "->");
  117. ASSERT_GT(substr.size(), 1);
  118. auto key = substr[0];
  119. auto value = substr[1];
  120. if (!value.empty()) {
  121. expected_[key] = {value};
  122. }
  123. } else if (spec.find(":") != std::string::npos) {
  124. auto substr = SplitString(spec, ":");
  125. ASSERT_GT(substr.size(), 1);
  126. auto key = substr[0];
  127. auto values = SplitString(substr[1], ",");
  128. std::set<std::string> values_set(values.begin(), values.end());
  129. if (!values_set.empty()) {
  130. expected_[key] = values_set;
  131. }
  132. } else {
  133. expected_recursive_[spec] = true;
  134. }
  135. }
  136. }
  137. void CheckParent(std::shared_ptr<ParentComputer> results) {
  138. std::map<std::string, std::set<std::string>> clean_results;
  139. for (auto& iter : results->parent_analysis()) {
  140. auto key = iter.first;
  141. auto value = iter.second;
  142. if (key == nullptr) {
  143. continue;
  144. }
  145. std::string k = Name(key);
  146. std::set<std::string> v;
  147. if (value != nullptr && !Name(value).empty()) {
  148. v.insert(Name(value));
  149. }
  150. if (!v.empty()) {
  151. clean_results[k] = v;
  152. }
  153. }
  154. ASSERT_EQ(clean_results, expected_);
  155. }
  156. void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
  157. std::map<std::string, bool> clean_results;
  158. for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
  159. auto key = iter->first;
  160. auto value = iter->second;
  161. if (key == nullptr) {
  162. continue;
  163. }
  164. std::string k = Name(key);
  165. clean_results[k] = value;
  166. }
  167. ASSERT_EQ(clean_results, expected_recursive_);
  168. }
  169. private:
  170. Stage* stage_;
  171. std::map<std::string, std::set<std::string>> expected_;
  172. std::map<std::string, bool> expected_recursive_;
  173. };
  174. bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
  175. for (auto node : manager->all_nodes()) {
  176. if (node->isa<CNode>()) {
  177. auto& inputs = node->cast<CNodePtr>()->inputs();
  178. for (size_t i = 0; i < inputs.size(); ++i) {
  179. auto inp = inputs[i];
  180. if (!manager->all_nodes().contains(inp)) {
  181. return false;
  182. }
  183. if (manager->node_users().find(inp) != manager->node_users().end()) {
  184. auto users = manager->node_users()[inp];
  185. if (!users.contains(make_pair(node, i))) {
  186. return false;
  187. }
  188. }
  189. }
  190. }
  191. if (manager->node_users().find(node) != manager->node_users().end()) {
  192. auto users = manager->node_users()[node];
  193. for (auto iter = users.begin(); iter != users.end(); ++iter) {
  194. auto node2 = iter->first;
  195. auto key = iter->second;
  196. if (!manager->all_nodes().contains(node2)) {
  197. return false;
  198. }
  199. if (node2->cast<CNodePtr>()->input(key) != node) {
  200. return false;
  201. }
  202. }
  203. }
  204. }
  205. return true;
  206. }
  207. class TestManager : public UT::Common {
  208. public:
  209. TestManager() : getPyFun("gtest_input.ir.manager_test") {}
  210. void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
  211. public:
  212. std::vector<PrimitivePtr> swaps;
  213. UT::PyFuncGraphFetcher getPyFun;
  214. };
  215. FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  216. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  217. ParameterPtr x = func_graph->add_parameter();
  218. ParameterPtr y = func_graph->add_parameter();
  219. std::vector<AnfNodePtr> inputs;
  220. inputs.push_back(NewValueNode(prim));
  221. inputs.push_back(x);
  222. inputs.push_back(y);
  223. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  224. inputs.clear();
  225. inputs.push_back(NewValueNode(prim::kPrimReturn));
  226. inputs.push_back(cnode_add);
  227. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  228. func_graph->set_return(cnode_return);
  229. return func_graph;
  230. }
  231. std::vector<FuncGraphPtr> MakeNestedGraph() {
  232. /*
  233. *def f(x):
  234. * def g():
  235. * return x
  236. * return g
  237. */
  238. FuncGraphPtr f = std::make_shared<FuncGraph>();
  239. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  240. ParameterPtr x = f->add_parameter();
  241. std::vector<AnfNodePtr> inputs;
  242. inputs.push_back(NewValueNode(fg));
  243. inputs.push_back(NewValueNode(prim::kPrimReturn));
  244. CNodePtr cnode_f = f->NewCNode(inputs);
  245. f->set_return(cnode_f);
  246. inputs.clear();
  247. inputs.push_back(NewValueNode(prim::kPrimReturn));
  248. inputs.push_back(x);
  249. CNodePtr cnode_g = fg->NewCNode(inputs);
  250. fg->set_return(cnode_g);
  251. std::vector<FuncGraphPtr> result = {f, fg};
  252. return result;
  253. }
  254. std::vector<FuncGraphPtr> MakeNestedGraph2() {
  255. /* build a closure func_graph */
  256. /*
  257. *def foo(x, y):
  258. * def bar(x1):
  259. * return x1 + y
  260. * return bar(x)
  261. */
  262. FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
  263. ParameterPtr x = graph_foo->add_parameter();
  264. ParameterPtr y = graph_foo->add_parameter();
  265. std::vector<AnfNodePtr> inputs;
  266. // build func_graph bar
  267. FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
  268. ParameterPtr x1 = graph_bar->add_parameter();
  269. inputs.clear();
  270. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  271. inputs.push_back(x1);
  272. inputs.push_back(y);
  273. CNodePtr cnode_add = graph_bar->NewCNode(inputs);
  274. inputs.clear();
  275. inputs.push_back(NewValueNode(prim::kPrimReturn));
  276. inputs.push_back(cnode_add);
  277. CNodePtr cnode_return = graph_bar->NewCNode(inputs);
  278. graph_bar->set_return(cnode_return);
  279. // build func_graph foo
  280. inputs.clear();
  281. inputs.push_back(NewValueNode(graph_bar));
  282. inputs.push_back(x);
  283. CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
  284. inputs.clear();
  285. inputs.push_back(NewValueNode(prim::kPrimReturn));
  286. inputs.push_back(cnode_graph_bar);
  287. cnode_return = graph_foo->NewCNode(inputs);
  288. graph_foo->set_return(cnode_return);
  289. std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
  290. return result;
  291. }
  292. // Add TestManager::CheckManager function to checkout the result
  293. void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
  294. auto size = mng->func_graphs().size();
  295. ASSERT_EQ(size, mng->free_variables_total().size());
  296. }
  297. TEST_F(TestManager, test_scalar_add_manual) {
  298. auto prim_scalar_add = prim::kPrimScalarAdd;
  299. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  300. auto mng = Manage(func_graph);
  301. }
  302. TEST_F(TestManager, test_scalar_replace) {
  303. auto prim_scalar_add = prim::kPrimScalarAdd;
  304. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  305. ParameterPtr x = func_graph->add_parameter();
  306. ParameterPtr y = func_graph->add_parameter();
  307. std::vector<AnfNodePtr> inputs;
  308. inputs.push_back(NewValueNode(prim_scalar_add));
  309. inputs.push_back(x);
  310. inputs.push_back(y);
  311. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  312. inputs.clear();
  313. inputs.push_back(NewValueNode(prim::kPrimReturn));
  314. inputs.push_back(cnode_add);
  315. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  316. func_graph->set_return(cnode_return);
  317. auto mng = Manage(func_graph);
  318. std::cout << "start " << x->ToString() << std::endl;
  319. mng->Replace(cnode_add, x);
  320. }
  321. TEST_F(TestManager, test_nested_manual) {
  322. auto graphs = MakeNestedGraph();
  323. auto f = graphs[0];
  324. auto g = graphs[1];
  325. auto mng = Manage(f);
  326. ASSERT_EQ(6, mng->all_nodes().size());
  327. ASSERT_EQ(2, mng->func_graphs().size());
  328. ASSERT_EQ(4, mng->node_users().size());
  329. ASSERT_EQ(1, mng->roots().size());
  330. CheckAnalysisSize(mng);
  331. ASSERT_EQ(2, f->nodes().size());
  332. ASSERT_EQ(1, g->nodes().size());
  333. auto users = mng->node_users();
  334. for (auto& iter : users) {
  335. ASSERT_EQ(1, iter.second.size());
  336. }
  337. ASSERT_EQ(1, f->func_graphs_used().size());
  338. ASSERT_EQ(0, g->func_graphs_used().size());
  339. ASSERT_EQ(0, f->free_variables().size());
  340. ASSERT_EQ(1, g->free_variables().size());
  341. auto fv_total = mng->free_variables_total();
  342. ASSERT_EQ(0, fv_total[f].size());
  343. ASSERT_EQ(1, fv_total[g].size());
  344. ASSERT_EQ(0, f->func_graph_cnodes_index().size());
  345. ASSERT_EQ(1, g->func_graph_cnodes_index().size());
  346. }
  347. TEST_F(TestManager, test_deep_nested2_manual) {
  348. // create parser
  349. FuncGraphPtr func_graph = getPyFun("test_custom");
  350. return;
  351. // parse ast to func graph
  352. FuncGraphPtr gfn = BasicClone(func_graph);
  353. if (gfn == nullptr) {
  354. return;
  355. }
  356. auto mng = Manage(gfn);
  357. ASSERT_EQ(3, mng->func_graphs().size());
  358. ASSERT_EQ(1, mng->roots().size());
  359. ASSERT_EQ(4, gfn->nodes().size());
  360. ASSERT_EQ(20, mng->all_nodes().size());
  361. ASSERT_EQ(25, mng->node_users().size());
  362. CheckAnalysisSize(mng);
  363. }
  364. TEST_F(TestManager, test_deep_nested_manual) {
  365. FuncGraphPtr f = std::make_shared<FuncGraph>();
  366. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  367. FuncGraphPtr h = std::make_shared<FuncGraph>();
  368. ParameterPtr x = f->add_parameter();
  369. ParameterPtr y = f->add_parameter();
  370. ParameterPtr z = f->add_parameter();
  371. std::vector<AnfNodePtr> inputs;
  372. inputs.push_back(NewValueNode(fg));
  373. inputs.push_back(x);
  374. inputs.push_back(y);
  375. CNodePtr cnode_1 = f->NewCNode(inputs);
  376. inputs.clear();
  377. inputs.push_back(cnode_1);
  378. inputs.push_back(NewValueNode(prim::kPrimReturn));
  379. CNodePtr cnode_0 = f->NewCNode(inputs);
  380. f->set_return(cnode_0);
  381. ParameterPtr x1 = fg->add_parameter();
  382. ParameterPtr y1 = fg->add_parameter();
  383. inputs.clear();
  384. inputs.push_back(NewValueNode(h));
  385. inputs.push_back(x1);
  386. CNodePtr cnode_3 = fg->NewCNode(inputs);
  387. inputs.clear();
  388. inputs.push_back(cnode_3);
  389. inputs.push_back(NewValueNode(prim::kPrimReturn));
  390. CNodePtr cnode_2 = fg->NewCNode(inputs);
  391. fg->set_return(cnode_2);
  392. ParameterPtr x2 = h->add_parameter();
  393. inputs.clear();
  394. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  395. inputs.push_back(x2);
  396. inputs.push_back(y1);
  397. CNodePtr cnode_6 = h->NewCNode(inputs);
  398. inputs.clear();
  399. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  400. inputs.push_back(z);
  401. inputs.push_back(cnode_6);
  402. CNodePtr cnode_5 = h->NewCNode(inputs);
  403. inputs.clear();
  404. inputs.push_back(cnode_5);
  405. inputs.push_back(NewValueNode(prim::kPrimReturn));
  406. CNodePtr cnode_4 = h->NewCNode(inputs);
  407. h->set_return(cnode_4);
  408. auto mng = Manage(f);
  409. ASSERT_EQ(3, mng->func_graphs().size());
  410. ASSERT_EQ(1, mng->roots().size());
  411. ASSERT_EQ(20, mng->all_nodes().size());
  412. CheckAnalysisSize(mng);
  413. }
  414. TEST_F(TestManager, test_parent1_manual) {
  415. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  416. Parameter param(fg);
  417. std::vector<AnfNodePtr> params;
  418. CNodePtr app = std::make_shared<CNode>(params, fg);
  419. fg->set_return(app);
  420. fg->set_parameters(params);
  421. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  422. manager->AddFuncGraph(fg, true);
  423. FuncGraphPtr p = fg->parent();
  424. assert(p == nullptr);
  425. }
  426. TEST_F(TestManager, test_parent_manual) {
  427. auto prim_scalar_add = prim::kPrimScalarAdd;
  428. FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
  429. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  430. manager->AddFuncGraph(fg);
  431. FuncGraphPtr p = fg->parent();
  432. assert(p == nullptr);
  433. }
  434. TEST_F(TestManager, test_flat) {
  435. std::vector<std::shared_ptr<Stage>> stages;
  436. std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  437. std::map<std::string, int> size_list;
  438. size_list["nodes"] = 2;
  439. }
  440. TEST_F(TestManager, test_nested) {
  441. std::vector<std::shared_ptr<Stage>> stages;
  442. std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  443. std::map<std::string, int> size_list;
  444. return;
  445. }
  446. TEST_F(TestManager, test_calls) {
  447. std::vector<std::shared_ptr<Stage>> stages;
  448. std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
  449. "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  450. std::map<std::string, int> size_list;
  451. return;
  452. }
  453. TEST_F(TestManager, test_unused_param) {
  454. std::vector<std::shared_ptr<Stage>> stages;
  455. std::vector<std::string> specs = {"nodes=X:x,y"};
  456. std::map<std::string, int> size_list;
  457. }
  458. TEST_F(TestManager, test_cannot_replace_return) {
  459. FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  460. ASSERT_NE(fg, nullptr);
  461. auto mng = Manage(fg);
  462. ASSERT_EQ(fg->manager(), mng);
  463. ASSERT_NE(mng, nullptr);
  464. ASSERT_GT(fg->parameters().size(), 0);
  465. ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
  466. }
  467. TEST_F(TestManager, test_weak_manager) {
  468. FuncGraphPtr fg = getPyFun("ir_get_fn");
  469. auto mng1 = MakeManager({fg}, false);
  470. ASSERT_EQ(fg->manager(), nullptr);
  471. auto mng2 = MakeManager({fg}, true);
  472. ASSERT_EQ(fg->manager(), mng2);
  473. auto mng3 = MakeManager({fg}, false);
  474. ASSERT_EQ(fg->manager(), mng2);
  475. }
  476. TEST_F(TestManager, test_drop_root) {
  477. FuncGraphPtr fg = getPyFun("ir_get_fn");
  478. auto mng = Manage(fg);
  479. const auto &fgs = mng->func_graphs();
  480. ASSERT_TRUE(fgs.contains(fg));
  481. FuncGraphSet s;
  482. s.add(fg);
  483. mng->MaybeDropFuncGraphs(s);
  484. ASSERT_TRUE(fgs.contains(fg));
  485. }
  486. TEST_F(TestManager, test_keep_roots) {
  487. FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  488. FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  489. auto mng = Manage(fg1);
  490. ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
  491. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  492. mng->AddFuncGraph(fg2);
  493. ASSERT_EQ(mng->func_graphs().size(), 2);
  494. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  495. mng->KeepRoots();
  496. ASSERT_EQ(mng->func_graphs().size(), 1);
  497. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  498. mng->KeepRoots({fg2});
  499. ASSERT_EQ(mng->func_graphs().size(), 1);
  500. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  501. }
  502. TEST_F(TestManager, test_keep_roots_recursion) {
  503. return;
  504. FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
  505. ASSERT_NE(fg, nullptr);
  506. auto mng = Manage(fg);
  507. parse::ResolveAll(mng);
  508. ASSERT_NE(mng, nullptr);
  509. ASSERT_EQ(mng->func_graphs().size(), 4);
  510. ASSERT_GT(fg->parameters().size(), 0);
  511. mng->Replace(fg->output(), fg->parameters()[0]);
  512. ASSERT_EQ(mng->func_graphs().size(), 3);
  513. mng->KeepRoots();
  514. ASSERT_EQ(mng->func_graphs().size(), 1);
  515. }
  516. } // namespace mindspore