| @@ -70,46 +70,21 @@ class ReplaceApplicator : public AnfVisitor { | |||
| } | |||
| }; | |||
| using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>; | |||
| class InlinerBase; | |||
| using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>; | |||
| bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| auto n_cnode = fg->nodes().size() - fg->parameters().size(); | |||
| // There is at least one CNode(return, other_node). | |||
| return n_cnode <= 2; | |||
| } | |||
| bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| auto &cnodes = fg->func_graph_cnodes_index(); | |||
| int64_t n_use = std::accumulate( | |||
| cnodes.begin(), cnodes.end(), 0, | |||
| [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; }); | |||
| return n_use == 1; | |||
| } | |||
| bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| return node->func_graph()->has_flag("inline_inside"); | |||
| } | |||
| bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } | |||
| bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } | |||
| bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { | |||
| bool unique_use = IsUniqueUse(fg, nullptr); | |||
| bool is_recursive = fg->recursive(); | |||
| if (fg->parent() != nullptr && is_recursive) { | |||
| if (fg->parent() == node->func_graph() && unique_use) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node); | |||
| bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node); | |||
| bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| // {G, Xs} | |||
| class InlinerBase : public AnfVisitor { | |||
| public: | |||
| explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true) | |||
| explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true) | |||
| : use_move_(use_move), criterions_(criterions) {} | |||
| ~InlinerBase() override = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -135,22 +110,24 @@ class InlinerBase : public AnfVisitor { | |||
| return nullptr; | |||
| } | |||
| } | |||
| Reset(); | |||
| // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...} | |||
| // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND]. | |||
| // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR]. | |||
| bool is_match = false; | |||
| for (auto &criterion : criterions_) { | |||
| if (!criterion.first(fg, node)) { | |||
| continue; | |||
| for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_. | |||
| is_match = true; | |||
| for (auto &criterion : criterions) { // Each criterion in 'criterion group'. | |||
| if (!criterion(this, fg, node)) { | |||
| is_match = false; | |||
| break; | |||
| } | |||
| } | |||
| if (criterion.second && IsRecursive(fg)) { | |||
| continue; | |||
| if (is_match) { | |||
| break; | |||
| } | |||
| is_match = true; | |||
| break; | |||
| } | |||
| if (!is_match) { | |||
| return nullptr; | |||
| } | |||
| @@ -162,24 +139,19 @@ class InlinerBase : public AnfVisitor { | |||
| if (fg->parameters().size() != args.size()) { | |||
| return nullptr; | |||
| } | |||
| auto is_unique_use = IsUniqueUse(fg, nullptr); | |||
| // Not to inline after block if it has switch call inside, to avoid switch expansion. | |||
| if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { | |||
| auto has_branch_call = GraphHasBranch(fg); | |||
| if (has_branch_call) { | |||
| return TransformBranchCall(fg, node, args); | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| if (use_move_) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| ReplaceParams(mng, args, fg); | |||
| auto out_node = fg->output(); | |||
| mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); | |||
| return out_node; | |||
| } | |||
| } else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { | |||
| // Not to inline after block if it has switch call inside, to avoid switch expansion. | |||
| return TransformBranchCall(fg, node, args); | |||
| } | |||
| if (use_move_ && is_unique_use) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| ReplaceParams(mng, args, fg); | |||
| auto out_node = fg->output(); | |||
| mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); | |||
| return out_node; | |||
| } | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| @@ -208,6 +180,7 @@ class InlinerBase : public AnfVisitor { | |||
| is_checked_ = false; | |||
| is_recursive_ = false; | |||
| } | |||
| // For after block which contains branch call, delete the parameters which is not used. | |||
| // In most cases, it may be a `Module` or other constant input. | |||
| AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { | |||
| @@ -298,25 +271,62 @@ class InlinerBase : public AnfVisitor { | |||
| private: | |||
| bool is_checked_{false}, is_recursive_{false}; | |||
| bool use_move_; | |||
| std::vector<std::pair<CriterionFuncType, bool>> criterions_; | |||
| std::vector<std::vector<CriterionFuncType>> criterions_; | |||
| std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; | |||
| // Key is the old func graph, and the value is the new func_graph | |||
| std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_; | |||
| }; | |||
| bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { | |||
| auto &cnodes = fg->func_graph_cnodes_index(); | |||
| int64_t n_use = std::accumulate( | |||
| cnodes.begin(), cnodes.end(), 0, | |||
| [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; }); | |||
| return n_use == 1; | |||
| } | |||
| bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { | |||
| auto n_cnode = fg->nodes().size() - fg->parameters().size(); | |||
| // There is at least one CNode(return, other_node). | |||
| return n_cnode <= 2; | |||
| } | |||
| bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| return node->func_graph()->has_flag("inline_inside"); | |||
| } | |||
| bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); } | |||
| bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) { | |||
| bool unique_use = IsUniqueUse(nullptr, fg, nullptr); | |||
| bool is_recursive = fg->recursive(); | |||
| if (fg->parent() != nullptr && is_recursive) { | |||
| if (fg->parent() == node->func_graph() && unique_use) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) { | |||
| return !inliner->IsRecursive(fg); | |||
| } | |||
| class Inliner : public InlinerBase { | |||
| public: | |||
| explicit Inliner(bool use_move = true) | |||
| : InlinerBase( | |||
| // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. | |||
| { | |||
| {IsUniqueUse, true}, | |||
| {IsTrivial, false}, | |||
| {IsInside, false}, | |||
| {IsCore, false}, | |||
| {IsDirectParentCall, false}, | |||
| {NoCriterion, true}, | |||
| {IsTrivial}, | |||
| {IsInside}, | |||
| {IsCore}, | |||
| {IsNotRecursive}, | |||
| {IsDirectParentCall}, | |||
| }, | |||
| use_move) {} | |||
| ~Inliner() override = default; | |||
| }; | |||
| @@ -324,8 +334,9 @@ class DirectInliner : public InlinerBase { | |||
| public: | |||
| explicit DirectInliner(bool use_move = true) | |||
| : InlinerBase( | |||
| // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. | |||
| { | |||
| {IsDirectParentCall, false}, | |||
| {IsDirectParentCall}, | |||
| }, | |||
| use_move) {} | |||
| ~DirectInliner() override = default; | |||