|
|
|
@@ -70,46 +70,21 @@ class ReplaceApplicator : public AnfVisitor { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>; |
|
|
|
class InlinerBase; |
|
|
|
using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>; |
|
|
|
|
|
|
|
bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { |
|
|
|
auto n_cnode = fg->nodes().size() - fg->parameters().size(); |
|
|
|
// There is at least one CNode(return, other_node). |
|
|
|
return n_cnode <= 2; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { |
|
|
|
auto &cnodes = fg->func_graph_cnodes_index(); |
|
|
|
int64_t n_use = std::accumulate( |
|
|
|
cnodes.begin(), cnodes.end(), 0, |
|
|
|
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; }); |
|
|
|
return n_use == 1; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node->func_graph()); |
|
|
|
return node->func_graph()->has_flag("inline_inside"); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } |
|
|
|
|
|
|
|
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } |
|
|
|
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); |
|
|
|
|
|
|
|
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { |
|
|
|
bool unique_use = IsUniqueUse(fg, nullptr); |
|
|
|
bool is_recursive = fg->recursive(); |
|
|
|
if (fg->parent() != nullptr && is_recursive) { |
|
|
|
if (fg->parent() == node->func_graph() && unique_use) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); |
|
|
|
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node); |
|
|
|
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); |
|
|
|
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node); |
|
|
|
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &); |
|
|
|
|
|
|
|
// {G, Xs} |
|
|
|
class InlinerBase : public AnfVisitor { |
|
|
|
public: |
|
|
|
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true) |
|
|
|
explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true) |
|
|
|
: use_move_(use_move), criterions_(criterions) {} |
|
|
|
~InlinerBase() override = default; |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
@@ -135,22 +110,24 @@ class InlinerBase : public AnfVisitor { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Reset(); |
|
|
|
|
|
|
|
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...} |
|
|
|
// All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND]. |
|
|
|
// Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR]. |
|
|
|
bool is_match = false; |
|
|
|
for (auto &criterion : criterions_) { |
|
|
|
if (!criterion.first(fg, node)) { |
|
|
|
continue; |
|
|
|
for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_. |
|
|
|
is_match = true; |
|
|
|
for (auto &criterion : criterions) { // Each criterion in 'criterion group'. |
|
|
|
if (!criterion(this, fg, node)) { |
|
|
|
is_match = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (criterion.second && IsRecursive(fg)) { |
|
|
|
continue; |
|
|
|
if (is_match) { |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
is_match = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
if (!is_match) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -162,24 +139,19 @@ class InlinerBase : public AnfVisitor { |
|
|
|
if (fg->parameters().size() != args.size()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto is_unique_use = IsUniqueUse(fg, nullptr); |
|
|
|
// Not to inline after block if it has switch call inside, to avoid switch expansion. |
|
|
|
if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { |
|
|
|
auto has_branch_call = GraphHasBranch(fg); |
|
|
|
if (has_branch_call) { |
|
|
|
return TransformBranchCall(fg, node, args); |
|
|
|
if (IsUniqueUse(nullptr, fg, nullptr)) { |
|
|
|
if (use_move_) { |
|
|
|
auto mng = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
ReplaceParams(mng, args, fg); |
|
|
|
auto out_node = fg->output(); |
|
|
|
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); |
|
|
|
return out_node; |
|
|
|
} |
|
|
|
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { |
|
|
|
// Not to inline after block if it has switch call inside, to avoid switch expansion. |
|
|
|
return TransformBranchCall(fg, node, args); |
|
|
|
} |
|
|
|
|
|
|
|
if (use_move_ && is_unique_use) { |
|
|
|
auto mng = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
ReplaceParams(mng, args, fg); |
|
|
|
auto out_node = fg->output(); |
|
|
|
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); |
|
|
|
return out_node; |
|
|
|
} |
|
|
|
|
|
|
|
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -208,6 +180,7 @@ class InlinerBase : public AnfVisitor { |
|
|
|
is_checked_ = false; |
|
|
|
is_recursive_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
// For after block which contains branch call, delete the parameters which is not used. |
|
|
|
// In most cases, it may be a `Module` or other constant input. |
|
|
|
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { |
|
|
|
@@ -298,25 +271,62 @@ class InlinerBase : public AnfVisitor { |
|
|
|
private: |
|
|
|
bool is_checked_{false}, is_recursive_{false}; |
|
|
|
bool use_move_; |
|
|
|
std::vector<std::pair<CriterionFuncType, bool>> criterions_; |
|
|
|
std::vector<std::vector<CriterionFuncType>> criterions_; |
|
|
|
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; |
|
|
|
// Key is the old func graph, and the value is the new func_graph |
|
|
|
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_; |
|
|
|
}; |
|
|
|
|
|
|
|
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { |
|
|
|
auto &cnodes = fg->func_graph_cnodes_index(); |
|
|
|
int64_t n_use = std::accumulate( |
|
|
|
cnodes.begin(), cnodes.end(), 0, |
|
|
|
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; }); |
|
|
|
return n_use == 1; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { |
|
|
|
auto n_cnode = fg->nodes().size() - fg->parameters().size(); |
|
|
|
// There is at least one CNode(return, other_node). |
|
|
|
return n_cnode <= 2; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node->func_graph()); |
|
|
|
return node->func_graph()->has_flag("inline_inside"); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); } |
|
|
|
|
|
|
|
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) { |
|
|
|
bool unique_use = IsUniqueUse(nullptr, fg, nullptr); |
|
|
|
bool is_recursive = fg->recursive(); |
|
|
|
if (fg->parent() != nullptr && is_recursive) { |
|
|
|
if (fg->parent() == node->func_graph() && unique_use) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) { |
|
|
|
return !inliner->IsRecursive(fg); |
|
|
|
} |
|
|
|
|
|
|
|
class Inliner : public InlinerBase { |
|
|
|
public: |
|
|
|
explicit Inliner(bool use_move = true) |
|
|
|
: InlinerBase( |
|
|
|
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. |
|
|
|
{ |
|
|
|
{IsUniqueUse, true}, |
|
|
|
{IsTrivial, false}, |
|
|
|
{IsInside, false}, |
|
|
|
{IsCore, false}, |
|
|
|
{IsDirectParentCall, false}, |
|
|
|
{NoCriterion, true}, |
|
|
|
{IsTrivial}, |
|
|
|
{IsInside}, |
|
|
|
{IsCore}, |
|
|
|
{IsNotRecursive}, |
|
|
|
{IsDirectParentCall}, |
|
|
|
}, |
|
|
|
use_move) {} |
|
|
|
|
|
|
|
~Inliner() override = default; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -324,8 +334,9 @@ class DirectInliner : public InlinerBase { |
|
|
|
public: |
|
|
|
explicit DirectInliner(bool use_move = true) |
|
|
|
: InlinerBase( |
|
|
|
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. |
|
|
|
{ |
|
|
|
{IsDirectParentCall, false}, |
|
|
|
{IsDirectParentCall}, |
|
|
|
}, |
|
|
|
use_move) {} |
|
|
|
~DirectInliner() override = default; |
|
|
|
|