diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index df28004ae9..f893f5df93 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -70,46 +70,21 @@ class ReplaceApplicator : public AnfVisitor { } }; -using CriterionFuncType = std::function; +class InlinerBase; +using CriterionFuncType = std::function; -bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { - auto n_cnode = fg->nodes().size() - fg->parameters().size(); - // There is at least one CNode(return, other_node). - return n_cnode <= 2; -} - -bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { - auto &cnodes = fg->func_graph_cnodes_index(); - int64_t n_use = std::accumulate( - cnodes.begin(), cnodes.end(), 0, - [](int64_t sum, const std::pair &item) { return sum + item.second; }); - return n_use == 1; -} - -bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node->func_graph()); - return node->func_graph()->has_flag("inline_inside"); -} - -bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } - -bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } +bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); -bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { - bool unique_use = IsUniqueUse(fg, nullptr); - bool is_recursive = fg->recursive(); - if (fg->parent() != nullptr && is_recursive) { - if (fg->parent() == node->func_graph() && unique_use) { - return true; - } - } - return false; -} +bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); +bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node); +bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); +bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node); +bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &); // {G, Xs} class InlinerBase : public AnfVisitor { public: - explicit InlinerBase(std::vector> criterions, bool use_move = true) + explicit InlinerBase(std::vector> criterions, bool use_move = true) : use_move_(use_move), criterions_(criterions) {} ~InlinerBase() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -135,22 +110,24 @@ class InlinerBase : public AnfVisitor { return nullptr; } } - Reset(); + + // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...} + // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND]. + // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR]. bool is_match = false; - for (auto &criterion : criterions_) { - if (!criterion.first(fg, node)) { - continue; + for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_. + is_match = true; + for (auto &criterion : criterions) { // Each criterion in 'criterion group'. + if (!criterion(this, fg, node)) { + is_match = false; + break; + } } - - if (criterion.second && IsRecursive(fg)) { - continue; + if (is_match) { + break; } - - is_match = true; - break; } - if (!is_match) { return nullptr; } @@ -162,24 +139,19 @@ class InlinerBase : public AnfVisitor { if (fg->parameters().size() != args.size()) { return nullptr; } - auto is_unique_use = IsUniqueUse(fg, nullptr); - // Not to inline after block if it has switch call inside, to avoid switch expansion. - if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { - auto has_branch_call = GraphHasBranch(fg); - if (has_branch_call) { - return TransformBranchCall(fg, node, args); + if (IsUniqueUse(nullptr, fg, nullptr)) { + if (use_move_) { + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + ReplaceParams(mng, args, fg); + auto out_node = fg->output(); + mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); + return out_node; } + } else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { + // Not to inline after block if it has switch call inside, to avoid switch expansion. + return TransformBranchCall(fg, node, args); } - - if (use_move_ && is_unique_use) { - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - ReplaceParams(mng, args, fg); - auto out_node = fg->output(); - mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); - return out_node; - } - return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); } @@ -208,6 +180,7 @@ class InlinerBase : public AnfVisitor { is_checked_ = false; is_recursive_ = false; } + // For after block which contains branch call, delete the parameters which is not used. // In most cases, it may be a `Module` or other constant input. AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector &args) { @@ -298,25 +271,62 @@ class InlinerBase : public AnfVisitor { private: bool is_checked_{false}, is_recursive_{false}; bool use_move_; - std::vector> criterions_; + std::vector> criterions_; std::unordered_map graph_branch_cache_; // Key is the old func graph, and the value is the new func_graph std::unordered_map transformed_branch_chache_; }; +bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { + auto &cnodes = fg->func_graph_cnodes_index(); + int64_t n_use = std::accumulate( + cnodes.begin(), cnodes.end(), 0, + [](int64_t sum, const std::pair &item) { return sum + item.second; }); + return n_use == 1; +} + +bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { + auto n_cnode = fg->nodes().size() - fg->parameters().size(); + // There is at least one CNode(return, other_node). + return n_cnode <= 2; +} + +bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node->func_graph()); + return node->func_graph()->has_flag("inline_inside"); +} + +bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); } + +bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) { + bool unique_use = IsUniqueUse(nullptr, fg, nullptr); + bool is_recursive = fg->recursive(); + if (fg->parent() != nullptr && is_recursive) { + if (fg->parent() == node->func_graph() && unique_use) { + return true; + } + } + return false; +} + +bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) { + return !inliner->IsRecursive(fg); +} + class Inliner : public InlinerBase { public: explicit Inliner(bool use_move = true) : InlinerBase( + // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. { - {IsUniqueUse, true}, - {IsTrivial, false}, - {IsInside, false}, - {IsCore, false}, - {IsDirectParentCall, false}, - {NoCriterion, true}, + {IsTrivial}, + {IsInside}, + {IsCore}, + {IsNotRecursive}, + {IsDirectParentCall}, }, use_move) {} + ~Inliner() override = default; }; @@ -324,8 +334,9 @@ class DirectInliner : public InlinerBase { public: explicit DirectInliner(bool use_move = true) : InlinerBase( + // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. { - {IsDirectParentCall, false}, + {IsDirectParentCall}, }, use_move) {} ~DirectInliner() override = default;