You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inline.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
  #include <algorithm>
  #include <functional>
  #include <numeric>
  #include <utility>
  #include <vector>

  #include "optimizer/irpass.h"
  #include "optimizer/optimizer.h"
  #include "ir/visitor.h"
  #include "ir/func_graph.h"
  #include "ir/func_graph_cloner.h"
  #include "operator/ops.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace irpass {
  30. class ReplaceApplicator : public AnfVisitor {
  31. public:
  32. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  33. if (!IsValueNode<FuncGraph>(node)) {
  34. return nullptr;
  35. }
  36. auto fg = GetValueNode<FuncGraphPtr>(node);
  37. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
  38. return nullptr;
  39. }
  40. auto out = fg->output();
  41. MS_EXCEPTION_IF_NULL(out);
  42. if (!out->isa<CNode>()) {
  43. return nullptr;
  44. }
  45. auto &inputs = out->cast<CNodePtr>()->inputs();
  46. auto params = fg->parameters();
  47. // Exclude first elements of inputs which is fn.
  48. auto input_size = inputs.size();
  49. auto param_size = params.size();
  50. if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
  51. std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
  52. auto inner = inputs[0];
  53. if (IsValueNode<Primitive>(inner) ||
  54. (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
  55. return inner;
  56. }
  57. }
  58. return nullptr;
  59. }
  60. };
  61. using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>;
  62. bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
  63. auto n_cnode = fg->nodes().size() - fg->parameters().size();
  64. // There is at least one CNode(return, other_node).
  65. return n_cnode <= 2;
  66. }
  67. bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
  68. auto &cnodes = fg->func_graph_cnodes_index();
  69. int n_use =
  70. std::accumulate(cnodes.begin(), cnodes.end(), 0,
  71. [](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
  72. return n_use == 1;
  73. }
  74. bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
  75. MS_EXCEPTION_IF_NULL(node->func_graph());
  76. return node->func_graph()->has_flag("inline_inside");
  77. }
  78. bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
  79. bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
  80. // {G, Xs}
  81. class InlinerBase : public AnfVisitor {
  82. public:
  83. explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
  84. ~InlinerBase() override = default;
  85. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  86. if (!node->isa<CNode>()) {
  87. return nullptr;
  88. }
  89. auto &inputs = node->cast<CNodePtr>()->inputs();
  90. if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
  91. return nullptr;
  92. }
  93. // G
  94. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  95. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
  96. return nullptr;
  97. }
  98. // Do not inline GraphKernel to Cell.
  99. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  100. // If the GraphKernel only contains a return node, we make it inlined.
  101. if (fg->nodes().size() - fg->parameters().size() > 1) {
  102. return nullptr;
  103. }
  104. }
  105. Reset();
  106. bool is_match = false;
  107. for (auto &criterion : criterions_) {
  108. if (!criterion.first(fg, node)) {
  109. continue;
  110. }
  111. if (criterion.second && IsRecursive(fg)) {
  112. continue;
  113. }
  114. is_match = true;
  115. break;
  116. }
  117. if (!is_match) {
  118. return nullptr;
  119. }
  120. std::vector<AnfNodePtr> params;
  121. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
  122. if (IsUniqueUse(fg, nullptr)) {
  123. auto mng = fg->manager();
  124. MS_EXCEPTION_IF_NULL(mng);
  125. ReplaceParams(mng, params, fg);
  126. auto out_node = fg->output();
  127. mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
  128. return out_node;
  129. }
  130. return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
  131. }
  132. void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
  133. const FuncGraphPtr &fg) {
  134. auto params = fg->parameters();
  135. auto old_size = params.size();
  136. if (old_size != new_params.size()) {
  137. MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
  138. << fg->output()->DebugString(10);
  139. }
  140. for (size_t i = 0; i < old_size; i++) {
  141. (void)mng->Replace(params[i], new_params[i]);
  142. }
  143. }
  144. bool IsRecursive(const FuncGraphPtr &fg) {
  145. if (!is_checked_) {
  146. is_checked_ = true;
  147. is_recursive_ = fg->recursive();
  148. }
  149. return is_recursive_;
  150. }
  151. void Reset() {
  152. is_checked_ = false;
  153. is_recursive_ = false;
  154. }
  155. private:
  156. bool is_checked_{false}, is_recursive_{false};
  157. std::vector<std::pair<CriterionFuncType, bool>> criterions_;
  158. };
  159. class Inliner : public InlinerBase {
  160. public:
  161. Inliner()
  162. : InlinerBase({
  163. {IsUniqueUse, true},
  164. {IsTrivial, false},
  165. {IsInside, false},
  166. {IsCore, false},
  167. {NoCriterion, true},
  168. }) {}
  169. ~Inliner() override = default;
  170. };
  171. } // namespace irpass
  172. } // namespace opt
  173. } // namespace mindspore
  174. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_