You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dead_node_eliminate.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "frontend/optimizer/dead_node_eliminate.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "utils/utils.h"
#include "base/core_ops.h"
#include "utils/func_graph_analyzer.h"
  25. namespace mindspore {
  26. namespace opt {
  27. namespace {
  28. bool IsFuncGraphCallNode(const AnfNodePtr &node) {
  29. if (!node->isa<CNode>()) {
  30. return false;
  31. }
  32. auto input0 = node->cast<CNodePtr>()->input(0);
  33. if (IsValueNode<Primitive>(input0)) {
  34. return false;
  35. }
  36. return true;
  37. }
  38. } // namespace
  39. class VisitContext {
  40. public:
  41. VisitContext() = default;
  42. explicit VisitContext(const std::vector<int64_t> &index_stack) { (void)index_stacks_.insert(index_stack); }
  43. ~VisitContext() = default;
  44. bool Add(const std::vector<int64_t> &index_stack) {
  45. if (index_stacks_.find(index_stack) != index_stacks_.end()) {
  46. return false;
  47. }
  48. (void)index_stacks_.insert(index_stack);
  49. return true;
  50. }
  51. bool IndexVisited(int64_t index) const {
  52. return std::any_of(index_stacks_.begin(), index_stacks_.end(), [&index](const std::vector<int64_t> &index_stack) {
  53. return !index_stack.empty() && index_stack.back() == index;
  54. });
  55. }
  56. std::set<std::vector<int64_t>> index_stacks_;
  57. };
  58. using VisitContextPtr = std::shared_ptr<VisitContext>;
  59. class ContextManager {
  60. public:
  61. ContextManager() = default;
  62. ~ContextManager() = default;
  63. HashMap<AnfNodePtr, VisitContextPtr> contexts_;
  64. bool AddContext(const AnfNodePtr &node, const std::vector<int64_t> &index_stack) {
  65. auto it = contexts_.find(node);
  66. if (it == contexts_.end()) {
  67. MS_LOG(DEBUG) << "Add node: " << node->DebugString();
  68. contexts_[node] = std::make_shared<VisitContext>(index_stack);
  69. return true;
  70. }
  71. return it->second->Add(index_stack);
  72. }
  73. bool IndexVisited(const CNodePtr &node, int64_t index) const {
  74. auto it = contexts_.find(node);
  75. if (it == contexts_.end()) {
  76. return false;
  77. }
  78. return it->second->IndexVisited(index);
  79. }
  80. };
  81. void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::vector<int64_t> index_stack, size_t seen,
  82. ContextManager *context_manager) {
  83. if (IS_OUTPUT_ON(DEBUG)) {
  84. MS_LOG(DEBUG) << "Visit node: " << node->DebugString();
  85. for (size_t i = 0; i < index_stack.size(); i++) {
  86. MS_LOG(DEBUG) << "index_stack[" << i << "]: " << index_stack[i];
  87. }
  88. }
  89. // If context exist, node need visit again to avoid repeatedly visiting.
  90. if (!context_manager->AddContext(node, index_stack)) {
  91. return;
  92. }
  93. node->seen_ = seen;
  94. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  95. auto tuple_getitem = node->cast<CNodePtr>();
  96. // Get cur index
  97. auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
  98. MS_EXCEPTION_IF_NULL(output_index_value_node);
  99. auto value_node = output_index_value_node->cast<ValueNodePtr>();
  100. MS_EXCEPTION_IF_NULL(value_node);
  101. auto output_idx = LongToSize(GetValue<int64_t>(value_node->value()));
  102. index_stack.push_back(output_idx);
  103. auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
  104. VisitNode(real_input, analyzer, index_stack, seen, context_manager);
  105. return;
  106. }
  107. if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
  108. // If make_tuple in make_tuple, visit may start with inner tuple_getitem.
  109. if (index_stack.empty()) {
  110. return;
  111. }
  112. auto make_tuple = node->cast<CNodePtr>();
  113. auto output_idx = index_stack.back();
  114. index_stack.pop_back();
  115. VisitNode(make_tuple->input(1 + output_idx), analyzer, index_stack, seen, context_manager);
  116. return;
  117. }
  118. if (IsFuncGraphCallNode(node)) {
  119. const auto &caller_func_graphs = analyzer.GetCallerFuncGraphs(node);
  120. for (const auto &fg : caller_func_graphs) {
  121. auto new_index_stack = std::vector<int64_t>(index_stack);
  122. VisitNode(fg->output(), analyzer, new_index_stack, seen, context_manager);
  123. }
  124. return;
  125. }
  126. if (node->isa<Parameter>()) {
  127. const auto &func_callers = analyzer.GetFuncGraphCallers(node->func_graph());
  128. for (auto &caller : func_callers) {
  129. const auto &args = analyzer.GetArg(node, caller);
  130. auto new_index_stack = std::vector<int64_t>(index_stack);
  131. for (const auto &arg : args) {
  132. VisitNode(arg, analyzer, new_index_stack, seen, context_manager);
  133. }
  134. }
  135. return;
  136. }
  137. if (node->isa<ValueTuple>()) {
  138. // TupleGetItem's input may not be a MakeTuple but a ValueTuple.
  139. return;
  140. }
  141. MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty.";
  142. }
  143. std::vector<AnfNodePtr> GenerateOutputTempGetItems(const FuncGraphPtr &func_graph) {
  144. std::vector<AnfNodePtr> output_tmp_getitems;
  145. std::deque<AnfNodePtr> todo = {func_graph->output()};
  146. while (!todo.empty()) {
  147. const auto node = todo.back();
  148. todo.pop_back();
  149. MS_EXCEPTION_IF_NULL(node->abstract());
  150. if (!node->abstract()->isa<abstract::AbstractTuple>()) {
  151. if (node != func_graph->output()) {
  152. output_tmp_getitems.emplace_back(node);
  153. }
  154. continue;
  155. }
  156. auto abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
  157. MS_EXCEPTION_IF_NULL(abstract_tuple);
  158. int64_t index = 0;
  159. for (const auto &elm : abstract_tuple->elements()) {
  160. auto new_tuple_getitem =
  161. func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(MakeValue(index))});
  162. new_tuple_getitem->set_abstract(elm);
  163. MS_LOG(INFO) << "New tuple getitem: " << new_tuple_getitem->DebugString() << ", index: " << index;
  164. todo.push_front(new_tuple_getitem);
  165. index++;
  166. }
  167. }
  168. return output_tmp_getitems;
  169. }
  170. bool IsScalarValueNode(const AnfNodePtr &node) {
  171. if (!IsValueNode<Scalar>(node)) {
  172. return false;
  173. }
  174. if (node->abstract() == nullptr) {
  175. return false;
  176. }
  177. return node->abstract()->isa<abstract::AbstractScalar>();
  178. }
  179. AnfNodePtr MakeScalarZero() {
  180. auto zero = NewValueNode(MakeValue(0));
  181. auto abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
  182. zero->set_abstract(abs);
  183. return zero;
  184. }
  185. bool EraseNode(const CNodePtr &cnode, size_t input_idx, const FuncGraphManagerPtr &manager) {
  186. // Scalar(int) no need convert to Scalar(0), and Scalar(0) cannot be erased once again.
  187. auto dead_node = cnode->input(input_idx);
  188. if (IsScalarValueNode(dead_node)) {
  189. return false;
  190. }
  191. MS_LOG(WARNING) << "Erase dead node: " << dead_node->DebugString() << ", user: " << cnode->DebugString();
  192. // Can't use `Replace`, must use `SetEdge`.
  193. manager->SetEdge(cnode, input_idx, MakeScalarZero());
  194. return true;
  195. }
  196. bool EraseMakeTupleInput(const std::vector<AnfNodePtr> &make_tuples, const FuncGraphPtr &func_graph,
  197. const ContextManager &context_manager, size_t seen) {
  198. bool change = false;
  199. for (const auto &make_tuple : make_tuples) {
  200. MS_LOG(DEBUG) << "Check make_tuple:" << make_tuple->DebugString();
  201. auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
  202. for (size_t i = 1; i < make_tuple_cnode->size(); i++) {
  203. // If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops.
  204. auto input_edge_visited = context_manager.IndexVisited(make_tuple_cnode, i - 1);
  205. // Can use `context_manager.contexts_.find(make_tuple_cnode) != context_manager.contexts_.end()`.
  206. auto make_tuple_visited = make_tuple_cnode->seen_ == seen;
  207. if (!input_edge_visited && make_tuple_visited) {
  208. change = EraseNode(make_tuple_cnode, i, func_graph->manager()) || change;
  209. }
  210. }
  211. }
  212. return change;
  213. }
  214. void VisitValue(const ValuePtr &value, std::vector<int64_t> indexes,
  215. HashMap<ValuePtr, HashSet<int64_t>> *visited_values) {
  216. MS_EXCEPTION_IF_NULL(value);
  217. MS_LOG(DEBUG) << "Visit value:" << value->ToString();
  218. if (indexes.empty()) {
  219. MS_LOG(DEBUG) << "Indexes empty";
  220. return;
  221. }
  222. const auto visit_index = indexes.back();
  223. (*visited_values)[value].insert(visit_index);
  224. auto value_tuple = value->cast<ValueTuplePtr>();
  225. MS_EXCEPTION_IF_NULL(value_tuple);
  226. if (LongToSize(visit_index) >= value_tuple->size()) {
  227. MS_LOG(EXCEPTION) << "Index: " << visit_index << " out of range: " << value_tuple->size();
  228. }
  229. indexes.pop_back();
  230. MS_LOG(DEBUG) << "Visit index: " << visit_index;
  231. VisitValue(value_tuple->value()[LongToSize(visit_index)], indexes, visited_values);
  232. }
// Rebuilds `value`/`abs`, recursively replacing every tuple element that was
// never visited with the scalar 0. When `need_erase` is set the whole value is
// replaced. Returns the (possibly new) value/abstract pair; returns the
// originals unchanged when nothing needed erasing.
std::pair<ValuePtr, abstract::AbstractBasePtr> EraseValue(const ValuePtr &value, const abstract::AbstractBasePtr &abs,
                                                          const HashMap<ValuePtr, HashSet<int64_t>> &visited_values,
                                                          bool need_erase) {
  if (need_erase) {
    // Replace the whole (dead) value with scalar 0 and a matching abstract.
    auto new_value = MakeValue(0);
    auto new_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
    new_abs->set_value(new_value);
    MS_LOG(WARNING) << "Erase value:" << value->ToString();
    return {new_value, new_abs};
  }
  // Value never recorded by VisitValue -> nothing below it was traced; keep it.
  auto it = visited_values.find(value);
  if (it == visited_values.end()) {
    return {value, abs};
  }
  const auto &all_visit_index = it->second;
  auto value_tuple = value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abs_tuple);
  // Work on copies; only materialize a new tuple if something changed.
  auto new_elements = std::vector<ValuePtr>(value_tuple->value());
  auto new_abstracts = std::vector<abstract::AbstractBasePtr>(abs_tuple->elements());
  if (new_elements.size() != new_abstracts.size()) {
    MS_LOG(EXCEPTION) << "Value size: " << new_elements.size()
                      << " is not equal to abstract size: " << new_abstracts.size();
  }
  bool change = false;
  for (size_t i = 0; i < value_tuple->value().size(); i++) {
    auto value_i = new_elements[i];
    auto abs_i = new_abstracts[i];
    // Avoid repeatedly erase.
    MS_LOG(DEBUG) << "value_i[" << i << "]: " << value_i->ToString();
    if (value_i->isa<Scalar>()) {
      continue;
    }
    // An element is dead when its index was never visited on this tuple.
    bool need_erase_i = all_visit_index.find(SizeToLong(i)) == all_visit_index.end();
    auto [ret_value, ret_abs] = EraseValue(value_i, abs_i, visited_values, need_erase_i);
    if (ret_value != value_i) {
      new_elements[i] = ret_value;
      new_abstracts[i] = ret_abs;
      change = true;
    }
  }
  if (change) {
    value_tuple = std::make_shared<ValueTuple>(new_elements);
    abs_tuple = std::make_shared<abstract::AbstractTuple>(new_abstracts);
    abs_tuple->set_value(value_tuple);
  }
  return {value_tuple, abs_tuple};
}
// For every visited ValueTuple node, computes which elements were actually
// consumed (by replaying the recorded index stacks) and replaces the dead
// elements with scalar 0, updating the node's value and abstract in place.
bool EraseDeadValues(const std::vector<AnfNodePtr> &value_tuple_nodes, const ContextManager &context_manager) {
  bool change = false;
  for (const auto &value_tuple : value_tuple_nodes) {
    // Skip value tuples that were never reached by the traversal.
    auto it = context_manager.contexts_.find(value_tuple);
    if (it == context_manager.contexts_.end()) {
      continue;
    }
    HashMap<ValuePtr, HashSet<int64_t>> visited_values;
    const auto value = GetValueNode(value_tuple);
    for (const auto &context : it->second->index_stacks_) {
      VisitValue(value, context, &visited_values);
    }
    // Erase the unvisited values.
    auto [new_value, new_abs] = EraseValue(value, value_tuple->abstract(), visited_values, false);
    if (new_value != value) {
      value_tuple->cast<ValueNodePtr>()->set_value(new_value);
      value_tuple->set_abstract(new_abs);
      MS_LOG(DEBUG) << "Set new value of node: " << value_tuple->DebugString();
      change = true;
    }
  }
  return change;
}
  305. std::shared_ptr<HashSet<size_t>> GetUsedParameters(const FuncGraphPtr &func_graph) {
  306. auto used_parameter_indexes = std::make_shared<HashSet<size_t>>();
  307. if (func_graph->manager() == nullptr) {
  308. return used_parameter_indexes;
  309. }
  310. const auto &manager_node_users = func_graph->manager()->node_users();
  311. const auto &parameters = func_graph->parameters();
  312. // Traverse to find all unused parameters.
  313. size_t index = 0;
  314. for (const auto &parameter : parameters) {
  315. const auto &node_users_it = manager_node_users.find(parameter);
  316. if (node_users_it != manager_node_users.end() && !node_users_it->second.empty()) {
  317. used_parameter_indexes->insert(index);
  318. }
  319. index++;
  320. }
  321. return used_parameter_indexes;
  322. }
  323. bool EraseArg(size_t user_index, const CNodePtr &arg_user, const FuncGraphManagerPtr &manager) {
  324. size_t arg_start_idx = 0;
  325. const size_t kFuncGraphCallArgStartIdx = 2;
  326. const size_t kPartialArgStartIdx = 2;
  327. if (IsFuncGraphCallNode(arg_user)) {
  328. arg_start_idx = kFuncGraphCallArgStartIdx;
  329. } else if (IsPrimitiveCNode(arg_user, prim::kPrimPartial)) {
  330. arg_start_idx = kPartialArgStartIdx;
  331. } else {
  332. MS_LOG(EXCEPTION) << "Unexpected arg user: " << arg_user->DebugString();
  333. }
  334. if (user_index < arg_start_idx) {
  335. return false;
  336. }
  337. return EraseNode(arg_user, user_index, manager);
  338. }
// Marks, for one closure bound at `call`, which argument edges feed used
// parameters. Partial-bound args (closure->arg_indexes_/arg_users_) come
// first, then the call-site args, matching the called graph's parameter order.
// Every arg user gets an entry in `visited_args` even when none of its indexes
// is visited, so the unused-arg eraser can still see it.
void VisitClosureArg(const FuncClosurePtr &closure, const CNodePtr &call, const HashSet<size_t> &used_indexes,
                     OrderedMap<CNodePtr, HashSet<size_t>> *visited_args) {
  auto arg_indexes = closure->arg_indexes_;
  // Add call node args to all args.
  auto arg_users = closure->arg_users_;
  for (size_t i = 1; i < call->inputs().size(); i++) {
    arg_indexes.push_back(i);
    arg_users.push_back(call);
  }
  const auto &fg = closure->func_graph_;
  // Combined arg list must line up 1:1 with the graph's parameters.
  if (arg_indexes.size() != fg->parameters().size()) {
    MS_LOG(EXCEPTION) << "Args size: " << arg_indexes.size()
                      << " is not equal to parameters size: " << fg->parameters().size()
                      << ". call: " << call->DebugString() << ", fg: " << fg->ToString();
  }
  for (size_t i = 0; i < arg_users.size(); i++) {
    // Insert an empty set so the arg user is recorded even with no visits.
    if (visited_args->find(arg_users[i]) == visited_args->end()) {
      (*visited_args)[arg_users[i]] = HashSet<size_t>();
    }
    if (used_indexes.find(i) != used_indexes.end()) {
      MS_LOG(DEBUG) << "Visit arg user: " << arg_users[i]->DebugString() << ", idx: " << arg_indexes[i];
      (*visited_args)[arg_users[i]].insert(arg_indexes[i]);
    }
  }
}
// If the parameter is a function parameter, the arg will be converted to a DeadNode after renormalize, so the arg
// needs to be erased.
bool EraseUnusedArgs(const std::vector<AnfNodePtr> &all_calls, const FuncGraphAnalyzer &analyzer,
                     const FuncGraphPtr &root_graph) {
  bool change = false;
  // OrderedMap<AnfNodePtr, OrderedSet<size_t>> call_unused_indexes;
  // Cache of per-graph used-parameter indexes to avoid recomputation.
  HashMap<FuncGraphPtr, std::shared_ptr<HashSet<size_t>>> func_graphs_used_indexes;
  // Per arg-user node: the set of its input indexes that feed used parameters.
  OrderedMap<CNodePtr, HashSet<size_t>> visited_args;
  // Visit all args of all calls.
  for (const auto &call : all_calls) {
    // Get unused indexes of call node.
    auto closures = analyzer.GetCallClosures(call);
    for (const auto &closure : closures) {
      std::shared_ptr<HashSet<size_t>> cur_fg_used_indexes;
      auto it = func_graphs_used_indexes.find(closure->func_graph_);
      if (it != func_graphs_used_indexes.end()) {
        cur_fg_used_indexes = it->second;
      } else {
        // Get unused parameter indexes of graph.
        cur_fg_used_indexes = GetUsedParameters(closure->func_graph_);
        func_graphs_used_indexes[closure->func_graph_] = cur_fg_used_indexes;
      }
      VisitClosureArg(closure, call->cast<CNodePtr>(), *cur_fg_used_indexes, &visited_args);
    }
  }
  // Erase unvisited args.
  for (const auto &[arg_user, visit_indexes] : visited_args) {
    for (size_t i = 0; i < arg_user->inputs().size(); i++) {
      if (visit_indexes.find(i) == visit_indexes.end()) {
        change = EraseArg(i, arg_user, root_graph->manager()) || change;
      }
    }
  }
  return change;
}
  400. // Visit graphs by DFS.
  401. void VisitGraph(const FuncGraphPtr &func_graph,
  402. const OrderedMap<FuncGraphPtr, OrderedSet<FuncGraphPtr>> &graph_relations,
  403. HashSet<FuncGraphPtr> *visited_graphs) {
  404. (void)visited_graphs->insert(func_graph);
  405. auto it = graph_relations.find(func_graph);
  406. if (it == graph_relations.end()) {
  407. return;
  408. }
  409. const auto &sub_graphs = it->second;
  410. for (const auto &sub_graph : sub_graphs) {
  411. if (visited_graphs->find(sub_graph) != visited_graphs->end()) {
  412. continue;
  413. }
  414. MS_LOG(DEBUG) << "Visit from graph: " << func_graph->ToString() << " to graph: " << sub_graph->ToString();
  415. VisitGraph(sub_graph, graph_relations, visited_graphs);
  416. }
  417. }
  418. bool EraseGraphCaller(const FuncGraphPtr &func_graph, const FuncGraphAnalyzer &analyzer,
  419. const FuncGraphManagerPtr &manager) {
  420. const auto &calls = analyzer.GetFuncGraphCallers(func_graph);
  421. bool change = false;
  422. for (const auto &call : calls) {
  423. auto call_closures = analyzer.GetCallClosures(call);
  424. // In fact, we can remove the arg user here, but in order to keep dead node eliminating strategy common, we consider
  425. // dead node only come from make tuple's input and caller's arg, so we erase the arg(which is input of arg user)
  426. // instead of arg user here.
  427. for (const auto &closure : call_closures) {
  428. for (size_t i = 0; i < closure->arg_users_.size(); i++) {
  429. EraseArg(closure->arg_indexes_[i], closure->arg_users_[i], manager);
  430. }
  431. }
  432. change = true;
  433. }
  434. return change;
  435. }
  436. std::shared_ptr<OrderedMap<FuncGraphPtr, OrderedSet<FuncGraphPtr>>> GetGraphRelations(
  437. const OrderedSet<FuncGraphPtr> &all_graphs, const FuncGraphAnalyzer &analyzer, const FuncGraphManagerPtr &manager) {
  438. auto graph_relations = std::make_shared<OrderedMap<FuncGraphPtr, OrderedSet<FuncGraphPtr>>>();
  439. for (const auto &func_graph : all_graphs) {
  440. const auto &graph_callers = analyzer.GetFuncGraphCallers(func_graph);
  441. for (const auto &caller : graph_callers) {
  442. // If call exist in graph.
  443. if (manager->all_nodes().find(caller) != manager->all_nodes().end()) {
  444. (*graph_relations)[caller->func_graph()].insert(func_graph);
  445. }
  446. }
  447. }
  448. return graph_relations;
  449. }
  450. bool EraseCircleGraphs(const FuncGraphPtr &root_graph, const FuncGraphAnalyzer &analyzer,
  451. const OrderedMap<FuncGraphPtr, OrderedSet<FuncGraphPtr>> &graph_relations) {
  452. HashSet<FuncGraphPtr> visited_graphs;
  453. VisitGraph(root_graph, graph_relations, &visited_graphs);
  454. bool change = false;
  455. // Eliminate unvisited graph's caller
  456. for (const auto &it : graph_relations) {
  457. const auto graph = it.first;
  458. if (graph->manager() == nullptr) {
  459. continue;
  460. }
  461. if (visited_graphs.find(graph) == visited_graphs.end()) {
  462. MS_LOG(WARNING) << "Erase unvisited graph: " << graph->ToString();
  463. change = EraseGraphCaller(graph, analyzer, graph->manager()) || change;
  464. }
  465. }
  466. return change;
  467. }
  468. std::shared_ptr<HashMap<std::string, std::vector<AnfNodePtr>>> SearchVisitNodes(const FuncGraphPtr &func_graph) {
  469. auto ret = std::make_shared<HashMap<std::string, std::vector<AnfNodePtr>>>();
  470. const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  471. for (const auto &node : all_nodes) {
  472. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  473. (*ret)["tuple_getitem"].emplace_back(node);
  474. } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
  475. (*ret)["make_tuple"].emplace_back(node);
  476. } else if (IsValueNode<ValueTuple>(node)) {
  477. (*ret)["value_tuple"].emplace_back(node);
  478. } else if (IsValueNode<FuncGraph>(node)) {
  479. (*ret)["graph_value_node"].emplace_back(node);
  480. } else if (IsFuncGraphCallNode(node)) {
  481. (*ret)["func_graph_call"].emplace_back(node);
  482. }
  483. }
  484. return ret;
  485. }
  486. std::shared_ptr<OrderedSet<FuncGraphPtr>> GetAllFuncGraphs(const std::vector<AnfNodePtr> &value_nodes) {
  487. auto func_graphs = std::make_shared<OrderedSet<FuncGraphPtr>>();
  488. std::for_each(value_nodes.begin(), value_nodes.end(),
  489. [&func_graphs](const AnfNodePtr &node) { func_graphs->insert(GetValueNode<FuncGraphPtr>(node)); });
  490. return func_graphs;
  491. }
// Pass entry point: iteratively replaces dead (never-consumed) tuple elements
// and unused call arguments with scalar 0 until a fixpoint is reached.
// Returns whether anything in the graph changed.
bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
  // Travers all tuple getitem nodes to visit.
  FuncGraphAnalyzer analyzer(func_graph);
  analyzer.Run();
  // Don't handle no-incorporate-call situation to improve performance.
  if (!analyzer.HasIncorporateCall()) {
    return false;
  }
  bool change = false;
  bool cycle_change = true;
  // Repeat until one full pass makes no change (erasing can expose new dead
  // nodes, so a single sweep is not enough).
  while (cycle_change) {
    cycle_change = false;
    ContextManager context_manager;
    auto visited_nodes = SearchVisitNodes(func_graph);
    // Fresh `seen` generation so stale marks from prior passes are ignored.
    auto seen = NewSeenGeneration();
    std::vector<int64_t> index_stack;
    // Visit from all tuple_getitem.
    for (const auto &tuple_getitem : (*visited_nodes)["tuple_getitem"]) {
      VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
    }
    // Visit from root graph output.
    const auto &output_getitems = GenerateOutputTempGetItems(func_graph);
    for (const auto &tuple_getitem : output_getitems) {
      VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
    }
    // 1. Erase all make tuple's unused input.
    cycle_change =
      EraseMakeTupleInput((*visited_nodes)["make_tuple"], func_graph, context_manager, seen) || cycle_change;
    // 2. Erase all value tuple's dead values.
    cycle_change = EraseDeadValues((*visited_nodes)["value_tuple"], context_manager) || cycle_change;
    // 3. Erase unused parameter's arg.
    const auto &all_func_graph_calls = (*visited_nodes)["func_graph_call"];
    cycle_change = EraseUnusedArgs(all_func_graph_calls, analyzer, func_graph) || cycle_change;
    // 4. Erase circle closures's all caller arg.
    // Erase circle graphs: caller[fg1] = fg2, caller[fg2] = fg1, fg1 and fg2 are redundant.
    auto all_graphs = GetAllFuncGraphs((*visited_nodes)["graph_value_node"]);
    auto graph_relations = GetGraphRelations(*all_graphs, analyzer, func_graph->manager());
    cycle_change = EraseCircleGraphs(func_graph, analyzer, *graph_relations) || cycle_change;
    change = change || cycle_change;
  }
  return change;
}
  534. } // namespace opt
  535. } // namespace mindspore