You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parameter_eliminate.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "utils/hash_set.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
  28. namespace mindspore {
  29. namespace opt {
  30. namespace irpass {
  31. class ParameterEliminator {
  32. public:
  33. ParameterEliminator() = default;
  34. virtual ~ParameterEliminator() = default;
  35. bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  36. const auto &manager = func_graph->manager();
  37. MS_EXCEPTION_IF_NULL(manager);
  38. bool changes = false;
  39. while (true) {
  40. const auto &[fg, callers] = SearchFuncGraphCallers(func_graph);
  41. if (fg == nullptr) {
  42. break;
  43. }
  44. auto manager = fg->manager();
  45. MS_EXCEPTION_IF_NULL(manager);
  46. const auto &erase_indexes = EraseUnusedParameters(fg, manager);
  47. for (auto caller : callers) {
  48. // Erase the corresponding args.
  49. EraseArgs(caller, erase_indexes, manager);
  50. }
  51. changes = true;
  52. }
  53. return changes;
  54. }
  55. private:
  56. static std::vector<CNodePtr> GetCallers(const FuncGraphPtr &fg) {
  57. const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
  58. std::vector<CNodePtr> caller_cnodes = {};
  59. // Find all caller of fg.
  60. for (const auto &it : fg_caller_and_indexes) {
  61. const auto &fg_caller_and_index = it.first;
  62. auto caller_cnode = fg_caller_and_index->first;
  63. auto index = fg_caller_and_index->second;
  64. // If index != 0, the caller is a indirect caller, can't erase the parameter of graph.Because
  65. // in this situation ValueNode<FuncGraph> is a input of Return or of MakeTuple.
  66. if (index != 0) {
  67. return {};
  68. }
  69. caller_cnodes.push_back(caller_cnode->cast<CNodePtr>());
  70. }
  71. return caller_cnodes;
  72. }
  73. static std::pair<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(const FuncGraphPtr &func_graph) {
  74. for (const auto &fg : func_graph->func_graphs_used_total()) {
  75. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
  76. continue;
  77. }
  78. const auto &parameters = fg->parameters();
  79. MS_EXCEPTION_IF_NULL(fg->manager());
  80. const auto &manager_node_users = fg->manager()->node_users();
  81. bool exist_param_unused =
  82. std::any_of(parameters.begin(), parameters.end(), [&manager_node_users](const AnfNodePtr &parameter) {
  83. const auto &node_users_it = manager_node_users.find(parameter);
  84. return node_users_it == manager_node_users.end() || node_users_it->second.empty();
  85. });
  86. if (exist_param_unused) {
  87. const auto &callers = GetCallers(fg);
  88. if (!callers.empty()) {
  89. return {fg, callers};
  90. }
  91. }
  92. }
  93. return {nullptr, {}};
  94. }
  95. static mindspore::HashSet<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) {
  96. MS_EXCEPTION_IF_NULL(fg->manager());
  97. const auto &manager_node_users = fg->manager()->node_users();
  98. const auto &parameters = fg->parameters();
  99. mindspore::HashSet<size_t> unused_parameter_indexes;
  100. // Traverse to find all unused parameters.
  101. size_t index = 0;
  102. for (const auto &parameter : parameters) {
  103. const auto &node_users_it = manager_node_users.find(parameter);
  104. if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
  105. unused_parameter_indexes.insert(index);
  106. }
  107. index++;
  108. }
  109. // Erase unused parameters.
  110. std::vector<AnfNodePtr> new_parameters;
  111. for (size_t i = 0; i < parameters.size(); i++) {
  112. if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
  113. new_parameters.push_back(parameters[i]);
  114. } else {
  115. MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
  116. }
  117. }
  118. manager->SetParameters(fg, new_parameters);
  119. return unused_parameter_indexes;
  120. }
  121. static void EraseArgs(const CNodePtr &caller, const mindspore::HashSet<size_t> &unused_parameter_indexes,
  122. const FuncGraphManagerPtr &manager) {
  123. std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
  124. for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
  125. if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
  126. new_args.push_back(caller->inputs()[i + 1]);
  127. } else {
  128. MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
  129. }
  130. }
  131. TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
  132. auto new_caller = caller->func_graph()->NewCNode(new_args);
  133. new_caller->set_abstract(caller->abstract());
  134. manager->Replace(caller, new_caller);
  135. }
  136. };
  137. } // namespace irpass
  138. } // namespace opt
  139. } // namespace mindspore
  140. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H