| @@ -134,12 +134,24 @@ class InlinerBase : public AnfVisitor { | |||
| std::vector<AnfNodePtr> args; | |||
| (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); | |||
| // compare size to avoid the case that the function has default value after grad. | |||
| // Compare size to avoid the case that the function has default value after grad. | |||
| // for which after renormalize, the function default value will be an input | |||
| if (fg->parameters().size() != args.size()) { | |||
| return nullptr; | |||
| } | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| // The other branch calling the last after block. | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { | |||
| // Check if parameters' changed. | |||
| auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); | |||
| if (param_simplified_caller != nullptr) { | |||
| return param_simplified_caller; | |||
| } | |||
| } | |||
| // For the single used fg, including non-after and after not matched above, | |||
| // we move the whole fg nodes. | |||
| if (use_move_) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| @@ -148,10 +160,20 @@ class InlinerBase : public AnfVisitor { | |||
| mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); | |||
| return out_node; | |||
| } | |||
| } else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { | |||
| // Not to inline after block if it has switch call inside, to avoid switch expansion. | |||
| return TransformBranchCall(fg, node, args); | |||
| } else { | |||
| // We don't expand the middle multiple used after block, except the last one. | |||
| if (GraphHasBranch(fg)) { | |||
| return nullptr; | |||
| } | |||
| // Check if parameters' changed for the first met branch calling. | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { | |||
| auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); | |||
| if (param_simplified_caller != nullptr) { | |||
| return param_simplified_caller; | |||
| } | |||
| } | |||
| } | |||
| // Or, just make a clone for not single used fg. | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| @@ -183,37 +205,34 @@ class InlinerBase : public AnfVisitor { | |||
| // For after block which contains branch call, delete the parameters which is not used. | |||
| // In most cases, it may be a `Module` or other constant input. | |||
| AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { | |||
| AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node, | |||
| const std::vector<AnfNodePtr> &args) { | |||
| auto &fg_params = fg->parameters(); | |||
| std::vector<int64_t> used_param_index; | |||
| auto mng = fg->manager(); | |||
| bool should_simplify = false; | |||
| for (size_t i = 0; i < fg_params.size(); i++) { | |||
| if (mng->node_users()[fg_params[i]].size() != 0) { | |||
| used_param_index.emplace_back(i); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString(); | |||
| should_simplify = true; | |||
| } | |||
| } | |||
| // If all parameters are used by cnodes | |||
| if (used_param_index.size() == fg_params.size()) { | |||
| if (!should_simplify) { | |||
| return nullptr; | |||
| } | |||
| if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) { | |||
| MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); | |||
| // clone a new graph and ignore the not used parameters | |||
| FuncGraphPtr new_fg = TransformableClone(fg); | |||
| auto &new_fg_params = new_fg->parameters(); | |||
| std::vector<AnfNodePtr> new_params; | |||
| std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), | |||
| [&new_fg_params](size_t i) { return new_fg_params[i]; }); | |||
| new_fg->set_parameters(new_params); | |||
| // New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined. | |||
| new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); | |||
| // Add new graph to the cache to improve perfomance when call HasBranchCall. | |||
| graph_branch_cache_[new_fg] = true; | |||
| // If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache. | |||
| transformed_branch_chache_[fg] = new_fg; | |||
| } | |||
| MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); | |||
| // Clone a new graph and ignore the not used parameters | |||
| auto new_fg = TransformableClone(fg); | |||
| auto &new_fg_params = new_fg->parameters(); | |||
| std::vector<AnfNodePtr> new_params; | |||
| std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), | |||
| [&new_fg_params](size_t i) { return new_fg_params[i]; }); | |||
| new_fg->set_parameters(new_params); | |||
| std::vector<AnfNodePtr> node_inputs; | |||
| node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg])); | |||
| node_inputs.push_back(NewValueNode(new_fg)); | |||
| std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs), | |||
| [&args](size_t i) { return args[i]; }); | |||
| return node->func_graph()->NewCNode(node_inputs); | |||
| @@ -273,8 +292,6 @@ class InlinerBase : public AnfVisitor { | |||
| bool use_move_; | |||
| std::vector<std::vector<CriterionFuncType>> criterions_; | |||
| std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; | |||
| // Key is the old func graph, and the value is the new func_graph | |||
| std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_; | |||
| }; | |||
| bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { | |||