You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inline.h 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
  #include <algorithm>
  #include <functional>
  #include <iterator>
  #include <numeric>
  #include <utility>
  #include <vector>
  #include "frontend/optimizer/irpass.h"
  #include "frontend/optimizer/optimizer.h"
  #include "frontend/optimizer/anf_visitor.h"
  #include "ir/func_graph.h"
  #include "ir/func_graph_cloner.h"
  #include "frontend/operator/ops.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace irpass {
  30. class ReplaceApplicator : public AnfVisitor {
  31. public:
  32. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  33. if (!IsValueNode<FuncGraph>(node)) {
  34. return nullptr;
  35. }
  36. auto fg = GetValueNode<FuncGraphPtr>(node);
  37. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
  38. return nullptr;
  39. }
  40. auto out = fg->output();
  41. MS_EXCEPTION_IF_NULL(out);
  42. if (!out->isa<CNode>()) {
  43. return nullptr;
  44. }
  45. auto &inputs = out->cast<CNodePtr>()->inputs();
  46. auto params = fg->parameters();
  47. // Exclude first elements of inputs which is fn.
  48. auto input_size = inputs.size();
  49. auto param_size = params.size();
  50. if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
  51. std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
  52. auto inner = inputs[0];
  53. if (IsValueNode<Primitive>(inner) ||
  54. (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
  55. return inner;
  56. }
  57. }
  58. return nullptr;
  59. }
  60. };
  // Signature of an inlining criterion: called with the candidate graph and the call-site
  // node, it returns true when the graph should be inlined at that call site.
  61. using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>;
  62. bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
  63. auto n_cnode = fg->nodes().size() - fg->parameters().size();
  64. // There is at least one CNode(return, other_node).
  65. return n_cnode <= 2;
  66. }
  67. bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
  68. auto &cnodes = fg->func_graph_cnodes_index();
  69. int n_use =
  70. std::accumulate(cnodes.begin(), cnodes.end(), 0,
  71. [](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
  72. return n_use == 1;
  73. }
  74. bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
  75. MS_EXCEPTION_IF_NULL(node->func_graph());
  76. return node->func_graph()->has_flag("inline_inside");
  77. }
  78. bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
  79. bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
  80. bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
  81. bool unique_use = IsUniqueUse(fg, nullptr);
  82. bool is_recursive = fg->recursive();
  83. if (fg->parent() != nullptr && is_recursive) {
  84. if (fg->parent() == node->func_graph() && unique_use) {
  85. return true;
  86. }
  87. }
  88. return false;
  89. }
  90. // {G, Xs}
  91. class InlinerBase : public AnfVisitor {
  92. public:
  93. explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
  94. : use_move_(use_move), criterions_(criterions) {}
  95. ~InlinerBase() override = default;
  96. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  97. if (!node->isa<CNode>()) {
  98. return nullptr;
  99. }
  100. auto &inputs = node->cast<CNodePtr>()->inputs();
  101. if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
  102. return nullptr;
  103. }
  104. // G
  105. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  106. if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
  107. return nullptr;
  108. }
  109. // Do not inline GraphKernel to Cell.
  110. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  111. // If the GraphKernel only contains a return node, we make it inlined.
  112. if (fg->nodes().size() - fg->parameters().size() > 1) {
  113. return nullptr;
  114. }
  115. }
  116. Reset();
  117. bool is_match = false;
  118. for (auto &criterion : criterions_) {
  119. if (!criterion.first(fg, node)) {
  120. continue;
  121. }
  122. if (criterion.second && IsRecursive(fg)) {
  123. continue;
  124. }
  125. is_match = true;
  126. break;
  127. }
  128. if (!is_match) {
  129. return nullptr;
  130. }
  131. std::vector<AnfNodePtr> params;
  132. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
  133. // compare size to avoid the case that the function has default value after grad.
  134. // for which after renormalize, the function default value will be an input
  135. if (fg->parameters().size() != params.size()) {
  136. return nullptr;
  137. }
  138. if (use_move_ && IsUniqueUse(fg, nullptr)) {
  139. auto mng = fg->manager();
  140. MS_EXCEPTION_IF_NULL(mng);
  141. ReplaceParams(mng, params, fg);
  142. auto out_node = fg->output();
  143. mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
  144. return out_node;
  145. }
  146. return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
  147. }
  148. void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
  149. const FuncGraphPtr &fg) {
  150. auto params = fg->parameters();
  151. auto old_size = params.size();
  152. if (old_size != new_params.size()) {
  153. MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
  154. << fg->output()->DebugString(10);
  155. }
  156. for (size_t i = 0; i < old_size; i++) {
  157. (void)mng->Replace(params[i], new_params[i]);
  158. }
  159. }
  160. bool IsRecursive(const FuncGraphPtr &fg) {
  161. if (!is_checked_) {
  162. is_checked_ = true;
  163. is_recursive_ = fg->recursive();
  164. }
  165. return is_recursive_;
  166. }
  167. void Reset() {
  168. is_checked_ = false;
  169. is_recursive_ = false;
  170. }
  171. private:
  172. bool is_checked_{false}, is_recursive_{false};
  173. bool use_move_;
  174. std::vector<std::pair<CriterionFuncType, bool>> criterions_;
  175. };
  176. class Inliner : public InlinerBase {
  177. public:
  178. explicit Inliner(bool use_move = true)
  179. : InlinerBase(
  180. {
  181. {IsUniqueUse, true},
  182. {IsTrivial, false},
  183. {IsInside, false},
  184. {IsCore, false},
  185. {IsDirectParentCall, false},
  186. {NoCriterion, true},
  187. },
  188. use_move) {}
  189. ~Inliner() override = default;
  190. };
  191. class DirectInliner : public InlinerBase {
  192. public:
  193. explicit DirectInliner(bool use_move = true)
  194. : InlinerBase(
  195. {
  196. {IsDirectParentCall, false},
  197. },
  198. use_move) {}
  199. ~DirectInliner() override = default;
  200. };
  201. } // namespace irpass
  202. } // namespace opt
  203. } // namespace mindspore
  204. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_