You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager_test.cc 18 kB


  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "common/py_func_graph_fetcher.h"
  18. #include "ir/dtype.h"
  19. #include "ir/manager.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "frontend/operator/ops.h"
  23. #include "utils/log_adapter.h"
  24. #include "debug/draw.h"
  25. #include "utils/label.h"
  26. namespace mindspore {
  27. namespace {
  // Splits `str` on every non-overlapping, left-to-right occurrence of `pattern` and returns the
  // pieces in order. Empty pieces are kept, so the result has one more element than there are
  // matches: "a,,b" -> {"a", "", "b"}, "a," -> {"a", ""}, "" -> {""}.
  // An empty `pattern` delimits nothing, so {str} is returned.
  std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
    if (pattern.empty()) {
      // find("") matches at every position and would never advance the scan.
      return {str};
    }
    std::vector<std::string> result;
    std::string::size_type start = 0;
    for (auto pos = str.find(pattern); pos != std::string::npos; pos = str.find(pattern, start)) {
      result.push_back(str.substr(start, pos - start));
      start = pos + pattern.size();
    }
    // Whatever follows the last match (possibly empty) is the final piece.
    result.push_back(str.substr(start));
    return result;
  }
  43. } // namespace
  44. using std::dynamic_pointer_cast;
  45. using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
  46. using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
  47. class NestingSpecs;
  48. class Stage {
  49. public:
  50. explicit Stage(std::vector<std::string> specs) {
  51. for (auto arg : specs) {
  52. auto spec = SplitString(arg, "=");
  53. if (spec.size() <= 1) {
  54. continue;
  55. }
  56. std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
  57. specs_[ToFullString(spec[0])] = nesting;
  58. }
  59. }
  60. ~Stage() {}
  61. std::map<std::string, std::string>& subs() { return subs_; }
  62. void set_subs(const std::map<std::string, std::string>& subs) { subs_ = subs; }
  63. private:
  64. std::string ToFullString(std::string s) {
  65. if (s.find("fv") != std::string::npos) {
  66. s = s.replace(s.find("fv"), 2, "free_variable");
  67. }
  68. if (s.find("deps") != std::string::npos) {
  69. s = s.replace(s.find("deps"), 4, "dependencies");
  70. }
  71. return s;
  72. }
  73. std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
  74. std::map<std::string, std::string> subs_;
  75. };
  76. class NestingSpecs {
  77. public:
  78. NestingSpecs(Stage* stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
  79. ~NestingSpecs() {}
  80. std::string Name(Any node) {
  81. std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
  82. if (stage_->subs().find(name) != stage_->subs().end()) {
  83. return stage_->subs()[name];
  84. }
  85. return name;
  86. }
  87. void Check(std::shared_ptr<DepComputer> results) {
  88. if (expected_.empty() && expected_recursive_.empty()) {
  89. return;
  90. }
  91. auto parent = dynamic_pointer_cast<ParentComputer>(results);
  92. if (parent != nullptr) {
  93. CheckParent(parent);
  94. return;
  95. }
  96. auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
  97. if (recursive != nullptr) {
  98. CheckRecursive(recursive);
  99. return;
  100. }
  101. }
  102. private:
  103. void ParseSpecs(std::string specs) {
  104. if (specs.empty()) {
  105. return;
  106. }
  107. std::vector<std::string> str_list = SplitString(specs, ";");
  108. for (auto spec : str_list) {
  109. spec.erase(0, spec.find_first_not_of(" "));
  110. spec.erase(spec.find_last_not_of(" ") + 1);
  111. if (spec.empty()) {
  112. continue;
  113. }
  114. if (spec.find("->") != std::string::npos) {
  115. auto substr = SplitString(spec, "->");
  116. ASSERT_GT(substr.size(), 1);
  117. auto key = substr[0];
  118. auto value = substr[1];
  119. if (!value.empty()) {
  120. expected_[key] = {value};
  121. }
  122. } else if (spec.find(":") != std::string::npos) {
  123. auto substr = SplitString(spec, ":");
  124. ASSERT_GT(substr.size(), 1);
  125. auto key = substr[0];
  126. auto values = SplitString(substr[1], ",");
  127. std::set<std::string> values_set(values.begin(), values.end());
  128. if (!values_set.empty()) {
  129. expected_[key] = values_set;
  130. }
  131. } else {
  132. expected_recursive_[spec] = true;
  133. }
  134. }
  135. }
  136. void CheckParent(std::shared_ptr<ParentComputer> results) {
  137. std::map<std::string, std::set<std::string>> clean_results;
  138. for (auto& iter : results->parent_analysis()) {
  139. auto key = iter.first;
  140. auto value = iter.second;
  141. if (key == nullptr) {
  142. continue;
  143. }
  144. std::string k = Name(key);
  145. std::set<std::string> v;
  146. if (value != nullptr && !Name(value).empty()) {
  147. v.insert(Name(value));
  148. }
  149. if (!v.empty()) {
  150. clean_results[k] = v;
  151. }
  152. }
  153. ASSERT_EQ(clean_results, expected_);
  154. }
  155. void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
  156. std::map<std::string, bool> clean_results;
  157. for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
  158. auto key = iter->first;
  159. auto value = iter->second;
  160. if (key == nullptr) {
  161. continue;
  162. }
  163. std::string k = Name(key);
  164. clean_results[k] = value;
  165. }
  166. ASSERT_EQ(clean_results, expected_recursive_);
  167. }
  168. private:
  169. Stage* stage_;
  170. std::map<std::string, std::set<std::string>> expected_;
  171. std::map<std::string, bool> expected_recursive_;
  172. };
  173. bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
  174. for (auto node : manager->all_nodes()) {
  175. if (node->isa<CNode>()) {
  176. auto& inputs = node->cast<CNodePtr>()->inputs();
  177. for (size_t i = 0; i < inputs.size(); ++i) {
  178. auto inp = inputs[i];
  179. if (!manager->all_nodes().contains(inp)) {
  180. return false;
  181. }
  182. if (manager->node_users().find(inp) != manager->node_users().end()) {
  183. auto users = manager->node_users()[inp];
  184. if (!users.contains(make_pair(node, i))) {
  185. return false;
  186. }
  187. }
  188. }
  189. }
  190. if (manager->node_users().find(node) != manager->node_users().end()) {
  191. auto users = manager->node_users()[node];
  192. for (auto iter = users.begin(); iter != users.end(); ++iter) {
  193. auto node2 = iter->first;
  194. auto key = iter->second;
  195. if (!manager->all_nodes().contains(node2)) {
  196. return false;
  197. }
  198. if (node2->cast<CNodePtr>()->input(key) != node) {
  199. return false;
  200. }
  201. }
  202. }
  203. }
  204. return true;
  205. }
  206. class TestManager : public UT::Common {
  207. public:
  208. TestManager() : getPyFun("gtest_input.ir.manager_test") {}
  209. void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
  210. public:
  211. std::vector<PrimitivePtr> swaps;
  212. UT::PyFuncGraphFetcher getPyFun;
  213. };
  214. FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  215. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  216. ParameterPtr x = func_graph->add_parameter();
  217. ParameterPtr y = func_graph->add_parameter();
  218. std::vector<AnfNodePtr> inputs;
  219. inputs.push_back(NewValueNode(prim));
  220. inputs.push_back(x);
  221. inputs.push_back(y);
  222. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  223. inputs.clear();
  224. inputs.push_back(NewValueNode(prim::kPrimReturn));
  225. inputs.push_back(cnode_add);
  226. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  227. func_graph->set_return(cnode_return);
  228. return func_graph;
  229. }
  230. std::vector<FuncGraphPtr> MakeNestedGraph() {
  231. /*
  232. *def f(x):
  233. * def g():
  234. * return x
  235. * return g
  236. */
  237. FuncGraphPtr f = std::make_shared<FuncGraph>();
  238. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  239. ParameterPtr x = f->add_parameter();
  240. std::vector<AnfNodePtr> inputs;
  241. inputs.push_back(NewValueNode(fg));
  242. inputs.push_back(NewValueNode(prim::kPrimReturn));
  243. CNodePtr cnode_f = f->NewCNode(inputs);
  244. f->set_return(cnode_f);
  245. inputs.clear();
  246. inputs.push_back(NewValueNode(prim::kPrimReturn));
  247. inputs.push_back(x);
  248. CNodePtr cnode_g = fg->NewCNode(inputs);
  249. fg->set_return(cnode_g);
  250. std::vector<FuncGraphPtr> result = {f, fg};
  251. return result;
  252. }
  253. std::vector<FuncGraphPtr> MakeNestedGraph2() {
  254. /* build a closure func_graph */
  255. /*
  256. *def foo(x, y):
  257. * def bar(x1):
  258. * return x1 + y
  259. * return bar(x)
  260. */
  261. FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
  262. ParameterPtr x = graph_foo->add_parameter();
  263. ParameterPtr y = graph_foo->add_parameter();
  264. std::vector<AnfNodePtr> inputs;
  265. // build func_graph bar
  266. FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
  267. ParameterPtr x1 = graph_bar->add_parameter();
  268. inputs.clear();
  269. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  270. inputs.push_back(x1);
  271. inputs.push_back(y);
  272. CNodePtr cnode_add = graph_bar->NewCNode(inputs);
  273. inputs.clear();
  274. inputs.push_back(NewValueNode(prim::kPrimReturn));
  275. inputs.push_back(cnode_add);
  276. CNodePtr cnode_return = graph_bar->NewCNode(inputs);
  277. graph_bar->set_return(cnode_return);
  278. // build func_graph foo
  279. inputs.clear();
  280. inputs.push_back(NewValueNode(graph_bar));
  281. inputs.push_back(x);
  282. CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
  283. inputs.clear();
  284. inputs.push_back(NewValueNode(prim::kPrimReturn));
  285. inputs.push_back(cnode_graph_bar);
  286. cnode_return = graph_foo->NewCNode(inputs);
  287. graph_foo->set_return(cnode_return);
  288. std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
  289. return result;
  290. }
  291. // Add TestManager::CheckManager function to checkout the result
  292. void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
  293. auto size = mng->func_graphs().size();
  294. ASSERT_EQ(size, mng->free_variables_total().size());
  295. }
  296. TEST_F(TestManager, test_scalar_add_manual) {
  297. auto prim_scalar_add = prim::kPrimScalarAdd;
  298. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  299. auto mng = Manage(func_graph);
  300. }
  301. TEST_F(TestManager, test_scalar_replace) {
  302. auto prim_scalar_add = prim::kPrimScalarAdd;
  303. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  304. ParameterPtr x = func_graph->add_parameter();
  305. ParameterPtr y = func_graph->add_parameter();
  306. std::vector<AnfNodePtr> inputs;
  307. inputs.push_back(NewValueNode(prim_scalar_add));
  308. inputs.push_back(x);
  309. inputs.push_back(y);
  310. CNodePtr cnode_add = func_graph->NewCNode(inputs);
  311. inputs.clear();
  312. inputs.push_back(NewValueNode(prim::kPrimReturn));
  313. inputs.push_back(cnode_add);
  314. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  315. func_graph->set_return(cnode_return);
  316. auto mng = Manage(func_graph);
  317. std::cout << "start " << x->ToString() << std::endl;
  318. mng->Replace(cnode_add, x);
  319. }
  320. TEST_F(TestManager, test_nested_manual) {
  321. auto graphs = MakeNestedGraph();
  322. auto f = graphs[0];
  323. auto g = graphs[1];
  324. auto mng = Manage(f);
  325. ASSERT_EQ(6, mng->all_nodes().size());
  326. ASSERT_EQ(2, mng->func_graphs().size());
  327. ASSERT_EQ(4, mng->node_users().size());
  328. ASSERT_EQ(1, mng->roots().size());
  329. CheckAnalysisSize(mng);
  330. ASSERT_EQ(2, f->nodes().size());
  331. ASSERT_EQ(1, g->nodes().size());
  332. auto &users = mng->node_users();
  333. for (auto& iter : users) {
  334. ASSERT_EQ(1, iter.second.size());
  335. }
  336. ASSERT_EQ(1, f->func_graphs_used().size());
  337. ASSERT_EQ(0, g->func_graphs_used().size());
  338. ASSERT_EQ(0, f->free_variables().size());
  339. ASSERT_EQ(1, g->free_variables().size());
  340. auto fv_total = mng->free_variables_total();
  341. ASSERT_EQ(0, fv_total[f].size());
  342. ASSERT_EQ(1, fv_total[g].size());
  343. ASSERT_EQ(0, f->func_graph_cnodes_index().size());
  344. ASSERT_EQ(1, g->func_graph_cnodes_index().size());
  345. }
  346. TEST_F(TestManager, test_deep_nested2_manual) {
  347. // create parser
  348. FuncGraphPtr func_graph = getPyFun("test_custom");
  349. return;
  350. // parse ast to func graph
  351. FuncGraphPtr gfn = BasicClone(func_graph);
  352. if (gfn == nullptr) {
  353. return;
  354. }
  355. auto mng = Manage(gfn);
  356. ASSERT_EQ(3, mng->func_graphs().size());
  357. ASSERT_EQ(1, mng->roots().size());
  358. ASSERT_EQ(4, gfn->nodes().size());
  359. ASSERT_EQ(20, mng->all_nodes().size());
  360. ASSERT_EQ(25, mng->node_users().size());
  361. CheckAnalysisSize(mng);
  362. }
  363. TEST_F(TestManager, test_deep_nested_manual) {
  364. FuncGraphPtr f = std::make_shared<FuncGraph>();
  365. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  366. FuncGraphPtr h = std::make_shared<FuncGraph>();
  367. ParameterPtr x = f->add_parameter();
  368. ParameterPtr y = f->add_parameter();
  369. ParameterPtr z = f->add_parameter();
  370. std::vector<AnfNodePtr> inputs;
  371. inputs.push_back(NewValueNode(fg));
  372. inputs.push_back(x);
  373. inputs.push_back(y);
  374. CNodePtr cnode_1 = f->NewCNode(inputs);
  375. inputs.clear();
  376. inputs.push_back(cnode_1);
  377. inputs.push_back(NewValueNode(prim::kPrimReturn));
  378. CNodePtr cnode_0 = f->NewCNode(inputs);
  379. f->set_return(cnode_0);
  380. ParameterPtr x1 = fg->add_parameter();
  381. ParameterPtr y1 = fg->add_parameter();
  382. inputs.clear();
  383. inputs.push_back(NewValueNode(h));
  384. inputs.push_back(x1);
  385. CNodePtr cnode_3 = fg->NewCNode(inputs);
  386. inputs.clear();
  387. inputs.push_back(cnode_3);
  388. inputs.push_back(NewValueNode(prim::kPrimReturn));
  389. CNodePtr cnode_2 = fg->NewCNode(inputs);
  390. fg->set_return(cnode_2);
  391. ParameterPtr x2 = h->add_parameter();
  392. inputs.clear();
  393. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  394. inputs.push_back(x2);
  395. inputs.push_back(y1);
  396. CNodePtr cnode_6 = h->NewCNode(inputs);
  397. inputs.clear();
  398. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  399. inputs.push_back(z);
  400. inputs.push_back(cnode_6);
  401. CNodePtr cnode_5 = h->NewCNode(inputs);
  402. inputs.clear();
  403. inputs.push_back(cnode_5);
  404. inputs.push_back(NewValueNode(prim::kPrimReturn));
  405. CNodePtr cnode_4 = h->NewCNode(inputs);
  406. h->set_return(cnode_4);
  407. auto mng = Manage(f);
  408. ASSERT_EQ(3, mng->func_graphs().size());
  409. ASSERT_EQ(1, mng->roots().size());
  410. ASSERT_EQ(20, mng->all_nodes().size());
  411. CheckAnalysisSize(mng);
  412. }
  413. TEST_F(TestManager, test_parent1_manual) {
  414. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  415. Parameter param(fg);
  416. std::vector<AnfNodePtr> params;
  417. CNodePtr app = std::make_shared<CNode>(params, fg);
  418. fg->set_return(app);
  419. fg->set_parameters(params);
  420. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  421. manager->AddFuncGraph(fg, true);
  422. FuncGraphPtr p = fg->parent();
  423. assert(p == nullptr);
  424. }
  425. TEST_F(TestManager, test_parent_manual) {
  426. auto prim_scalar_add = prim::kPrimScalarAdd;
  427. FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
  428. std::shared_ptr<FuncGraphManager> manager = MakeManager();
  429. manager->AddFuncGraph(fg);
  430. FuncGraphPtr p = fg->parent();
  431. assert(p == nullptr);
  432. }
  433. TEST_F(TestManager, test_flat) {
  434. std::vector<std::shared_ptr<Stage>> stages;
  435. std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  436. std::map<std::string, int> size_list;
  437. size_list["nodes"] = 2;
  438. }
  439. TEST_F(TestManager, test_nested) {
  440. std::vector<std::shared_ptr<Stage>> stages;
  441. std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  442. std::map<std::string, int> size_list;
  443. return;
  444. }
  445. TEST_F(TestManager, test_calls) {
  446. std::vector<std::shared_ptr<Stage>> stages;
  447. std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
  448. "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  449. std::map<std::string, int> size_list;
  450. return;
  451. }
  452. TEST_F(TestManager, test_unused_param) {
  453. std::vector<std::shared_ptr<Stage>> stages;
  454. std::vector<std::string> specs = {"nodes=X:x,y"};
  455. std::map<std::string, int> size_list;
  456. }
  457. TEST_F(TestManager, test_cannot_replace_return) {
  458. FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  459. ASSERT_NE(fg, nullptr);
  460. auto mng = Manage(fg);
  461. ASSERT_EQ(fg->manager(), mng);
  462. ASSERT_NE(mng, nullptr);
  463. ASSERT_GT(fg->parameters().size(), 0);
  464. ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
  465. }
  466. TEST_F(TestManager, test_weak_manager) {
  467. FuncGraphPtr fg = getPyFun("ir_get_fn");
  468. auto mng1 = MakeManager({fg}, false);
  469. ASSERT_EQ(fg->manager(), nullptr);
  470. auto mng2 = MakeManager({fg}, true);
  471. ASSERT_EQ(fg->manager(), mng2);
  472. auto mng3 = MakeManager({fg}, false);
  473. ASSERT_EQ(fg->manager(), mng2);
  474. }
  475. TEST_F(TestManager, test_drop_root) {
  476. FuncGraphPtr fg = getPyFun("ir_get_fn");
  477. auto mng = Manage(fg);
  478. const auto &fgs = mng->func_graphs();
  479. ASSERT_TRUE(fgs.contains(fg));
  480. FuncGraphSet s;
  481. s.add(fg);
  482. mng->MaybeDropFuncGraphs(s);
  483. ASSERT_TRUE(fgs.contains(fg));
  484. }
  485. TEST_F(TestManager, test_keep_roots) {
  486. FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  487. FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  488. auto mng = Manage(fg1);
  489. ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
  490. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  491. mng->AddFuncGraph(fg2);
  492. ASSERT_EQ(mng->func_graphs().size(), 2);
  493. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  494. mng->KeepRoots();
  495. ASSERT_EQ(mng->func_graphs().size(), 1);
  496. ASSERT_TRUE(mng->func_graphs().contains(fg1));
  497. mng->KeepRoots({fg2});
  498. ASSERT_EQ(mng->func_graphs().size(), 1);
  499. ASSERT_TRUE(mng->func_graphs().contains(fg2));
  500. }
  501. TEST_F(TestManager, test_keep_roots_recursion) {
  502. return;
  503. FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
  504. ASSERT_NE(fg, nullptr);
  505. auto mng = Manage(fg);
  506. parse::ResolveAll(mng);
  507. ASSERT_NE(mng, nullptr);
  508. ASSERT_EQ(mng->func_graphs().size(), 4);
  509. ASSERT_GT(fg->parameters().size(), 0);
  510. mng->Replace(fg->output(), fg->parameters()[0]);
  511. ASSERT_EQ(mng->func_graphs().size(), 3);
  512. mng->KeepRoots();
  513. ASSERT_EQ(mng->func_graphs().size(), 1);
  514. }
  515. } // namespace mindspore