|
|
|
@@ -134,12 +134,24 @@ class InlinerBase : public AnfVisitor { |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> args; |
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); |
|
|
|
// compare size to avoid the case that the function has default value after grad. |
|
|
|
// Compare size to avoid the case that the function has default value after grad. |
|
|
|
// for which after renormalize, the function default value will be an input |
|
|
|
if (fg->parameters().size() != args.size()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsUniqueUse(nullptr, fg, nullptr)) { |
|
|
|
// The other branch calling the last after block. |
|
|
|
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { |
|
|
|
// Check if parameters' changed. |
|
|
|
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); |
|
|
|
if (param_simplified_caller != nullptr) { |
|
|
|
return param_simplified_caller; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// For the single used fg, including non-after and after not matched above, |
|
|
|
// we move the whole fg nodes. |
|
|
|
if (use_move_) { |
|
|
|
auto mng = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
@@ -148,10 +160,20 @@ class InlinerBase : public AnfVisitor { |
|
|
|
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); |
|
|
|
return out_node; |
|
|
|
} |
|
|
|
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { |
|
|
|
// Not to inline after block if it has switch call inside, to avoid switch expansion. |
|
|
|
return TransformBranchCall(fg, node, args); |
|
|
|
} else { |
|
|
|
// We don't expand the middle multiple used after block, except the last one. |
|
|
|
if (GraphHasBranch(fg)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// Check if parameters' changed for the first met branch calling. |
|
|
|
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { |
|
|
|
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); |
|
|
|
if (param_simplified_caller != nullptr) { |
|
|
|
return param_simplified_caller; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// Or, just make a clone for not single used fg. |
|
|
|
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -183,37 +205,34 @@ class InlinerBase : public AnfVisitor { |
|
|
|
|
|
|
|
// For after block which contains branch call, delete the parameters which is not used. |
|
|
|
// In most cases, it may be a `Module` or other constant input. |
|
|
|
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { |
|
|
|
AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node, |
|
|
|
const std::vector<AnfNodePtr> &args) { |
|
|
|
auto &fg_params = fg->parameters(); |
|
|
|
std::vector<int64_t> used_param_index; |
|
|
|
auto mng = fg->manager(); |
|
|
|
bool should_simplify = false; |
|
|
|
for (size_t i = 0; i < fg_params.size(); i++) { |
|
|
|
if (mng->node_users()[fg_params[i]].size() != 0) { |
|
|
|
used_param_index.emplace_back(i); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString(); |
|
|
|
should_simplify = true; |
|
|
|
} |
|
|
|
} |
|
|
|
// If all parameters are used by cnodes |
|
|
|
if (used_param_index.size() == fg_params.size()) { |
|
|
|
if (!should_simplify) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) { |
|
|
|
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); |
|
|
|
// clone a new graph and ignore the not used parameters |
|
|
|
FuncGraphPtr new_fg = TransformableClone(fg); |
|
|
|
auto &new_fg_params = new_fg->parameters(); |
|
|
|
std::vector<AnfNodePtr> new_params; |
|
|
|
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), |
|
|
|
[&new_fg_params](size_t i) { return new_fg_params[i]; }); |
|
|
|
new_fg->set_parameters(new_params); |
|
|
|
// New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined. |
|
|
|
new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); |
|
|
|
// Add new graph to the cache to improve perfomance when call HasBranchCall. |
|
|
|
graph_branch_cache_[new_fg] = true; |
|
|
|
// If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache. |
|
|
|
transformed_branch_chache_[fg] = new_fg; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); |
|
|
|
// Clone a new graph and ignore the not used parameters |
|
|
|
auto new_fg = TransformableClone(fg); |
|
|
|
auto &new_fg_params = new_fg->parameters(); |
|
|
|
std::vector<AnfNodePtr> new_params; |
|
|
|
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), |
|
|
|
[&new_fg_params](size_t i) { return new_fg_params[i]; }); |
|
|
|
new_fg->set_parameters(new_params); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs; |
|
|
|
node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg])); |
|
|
|
node_inputs.push_back(NewValueNode(new_fg)); |
|
|
|
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs), |
|
|
|
[&args](size_t i) { return args[i]; }); |
|
|
|
return node->func_graph()->NewCNode(node_inputs); |
|
|
|
@@ -273,8 +292,6 @@ class InlinerBase : public AnfVisitor { |
|
|
|
bool use_move_; |
|
|
|
std::vector<std::vector<CriterionFuncType>> criterions_; |
|
|
|
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; |
|
|
|
// Key is the old func graph, and the value is the new func_graph |
|
|
|
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_; |
|
|
|
}; |
|
|
|
|
|
|
|
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { |
|
|
|
|