You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inline.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
  #include <algorithm>
  #include <functional>
  #include <iterator>
  #include <numeric>
  #include <utility>
  #include <vector>

  #include "optimizer/irpass.h"
  #include "optimizer/optimizer.h"
  #include "ir/visitor.h"
  #include "ir/func_graph.h"
  #include "ir/func_graph_cloner.h"
  #include "operator/ops.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace irpass {
  30. class ReplaceApplicator : public AnfVisitor {
  31. public:
  32. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  33. if (!IsValueNode<FuncGraph>(node)) {
  34. return nullptr;
  35. }
  36. auto fg = GetValueNode<FuncGraphPtr>(node);
  37. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
  38. return nullptr;
  39. }
  40. auto out = fg->output();
  41. MS_EXCEPTION_IF_NULL(out);
  42. if (!out->isa<CNode>()) {
  43. return nullptr;
  44. }
  45. auto &inputs = out->cast<CNodePtr>()->inputs();
  46. auto params = fg->parameters();
  47. // Exclude first elements of inputs which is fn.
  48. auto input_size = inputs.size();
  49. auto param_size = params.size();
  50. if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
  51. std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
  52. auto inner = inputs[0];
  53. if (IsValueNode<Primitive>(inner) ||
  54. (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
  55. return inner;
  56. }
  57. }
  58. return nullptr;
  59. }
  60. };
  61. using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>;
  62. bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
  63. auto &s = fg->nodes();
  64. int n_cnode = std::count_if(s.begin(), s.end(), [](const AnfNodePtr &n) {
  65. MS_EXCEPTION_IF_NULL(n);
  66. return n->isa<CNode>();
  67. });
  68. // There is at least one CNode(return, other_node).
  69. return n_cnode <= 2;
  70. }
  71. bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
  72. auto &users = fg->func_graph_users();
  73. int n_use =
  74. std::accumulate(users.begin(), users.end(), 0,
  75. [](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; });
  76. return n_use == 1;
  77. }
  78. bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
  79. MS_EXCEPTION_IF_NULL(node->func_graph());
  80. auto &flags = node->func_graph()->flags();
  81. if (flags.find("inline_inside") != flags.end()) {
  82. return flags["inline_inside"];
  83. }
  84. return false;
  85. }
  86. bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) {
  87. auto &flags = fg->flags();
  88. if (flags.find("core") != flags.end()) {
  89. return flags["core"];
  90. }
  91. return false;
  92. }
  93. bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
  94. // {G, Xs}
  95. class InlinerBase : public AnfVisitor {
  96. public:
  97. explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
  98. ~InlinerBase() override = default;
  99. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  100. if (!node->isa<CNode>()) {
  101. return nullptr;
  102. }
  103. auto &inputs = node->cast<CNodePtr>()->inputs();
  104. if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
  105. return nullptr;
  106. }
  107. // G
  108. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  109. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
  110. return nullptr;
  111. }
  112. Reset();
  113. bool is_match = false;
  114. for (auto &criterion : criterions_) {
  115. if (!criterion.first(fg, node)) {
  116. continue;
  117. }
  118. if (criterion.second && IsRecursive(fg)) {
  119. continue;
  120. }
  121. is_match = true;
  122. break;
  123. }
  124. if (!is_match) {
  125. return nullptr;
  126. }
  127. std::vector<AnfNodePtr> params;
  128. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
  129. if (IsUniqueUse(fg, nullptr)) {
  130. auto mng = fg->manager();
  131. MS_EXCEPTION_IF_NULL(mng);
  132. ReplaceParams(mng, params, fg);
  133. auto out_node = fg->output();
  134. mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
  135. return out_node;
  136. }
  137. return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
  138. }
  139. void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
  140. const FuncGraphPtr &fg) {
  141. auto params = fg->parameters();
  142. auto old_size = params.size();
  143. if (old_size != new_params.size()) {
  144. MS_LOG(EXCEPTION) << "Parameter size not match.";
  145. }
  146. for (size_t i = 0; i < old_size; i++) {
  147. (void)mng->Replace(params[i], new_params[i]);
  148. }
  149. }
  150. bool IsRecursive(const FuncGraphPtr &fg) {
  151. if (!is_checked_) {
  152. is_checked_ = true;
  153. is_recursive_ = fg->recursive();
  154. }
  155. return is_recursive_;
  156. }
  157. void Reset() {
  158. is_checked_ = false;
  159. is_recursive_ = false;
  160. }
  161. private:
  162. bool is_checked_{false}, is_recursive_{false};
  163. std::vector<std::pair<CriterionFuncType, bool>> criterions_;
  164. };
  165. class Inliner : public InlinerBase {
  166. public:
  167. Inliner()
  168. : InlinerBase({
  169. {IsUniqueUse, true},
  170. {IsTrivial, false},
  171. {IsInside, false},
  172. {IsCore, false},
  173. {NoCriterion, true},
  174. }) {}
  175. ~Inliner() override = default;
  176. };
  177. } // namespace irpass
  178. } // namespace opt
  179. } // namespace mindspore
  180. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_