From: @ginfung Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -24,6 +24,7 @@ | |||
| #include "abstract/abstract_function.h" | |||
| #include "utils/flags.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| @@ -32,6 +33,20 @@ using mindspore::abstract::AbstractBase; | |||
| using mindspore::abstract::AbstractFunction; | |||
| using mindspore::abstract::AbstractFunctionPtr; | |||
| bool WithRecomputedScope(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto full_name_with_scope = node->fullname_with_scope(); | |||
| return full_name_with_scope.find(kAttrRecompute) == 0; | |||
| } | |||
| bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) { | |||
| return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) || | |||
| (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute)); | |||
| } | |||
| BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto node_abs = node->abstract(); | |||
| @@ -83,7 +98,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { | |||
| } else if (node->isa<Parameter>()) { | |||
| h = node->hash(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unknow node type"; | |||
| MS_LOG(ERROR) << "Unknown node type"; | |||
| } | |||
| hashes[node] = h; | |||
| @@ -142,6 +157,10 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec | |||
| } else if (main->isa<CNode>() && node->isa<CNode>()) { | |||
| auto c_main = main->cast<CNodePtr>(); | |||
| auto c_node = node->cast<CNodePtr>(); | |||
| // Not do cse for the node set recompute before the recompute pass. | |||
| if (IsSetRecomputed(c_main, c_node)) { | |||
| return false; | |||
| } | |||
| // When appsame is true, check if has side effect, do not merge. | |||
| if (check_side_effect && HasSideEffect(main)) { | |||
| return false; | |||
| @@ -25,12 +25,12 @@ | |||
| #include <algorithm> | |||
| #include "ir/func_graph.h" | |||
| #include "mindspore/core/base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr auto kGradientsFlag = "Gradients"; | |||
| constexpr auto kAttrRecompute = "recompute"; | |||
| bool IsBpropNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| @@ -339,6 +339,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod | |||
| auto recomputed_node = graph->NewCNode(new_inputs); | |||
| MS_EXCEPTION_IF_NULL(recomputed_node); | |||
| recomputed_node->AddAttr("duplicated", MakeValue(true)); | |||
| recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); | |||
| recomputed_node->set_abstract(origin_node->abstract()); | |||
| recomputed_node->set_scope(origin_node->scope()); | |||
| origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); | |||
| @@ -415,6 +416,12 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) { | |||
| DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs, | |||
| &origin_to_recomputed_nodes); | |||
| } | |||
| // Set need cse attr for doing cse after recompute. | |||
| for (const auto &node : orders) { | |||
| if (WithRecomputedScope(node)) { | |||
| node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); | |||
| } | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -302,6 +302,11 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}}); | |||
| return map; | |||
| } | |||
| static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {}; | |||
| void InitOpt(const ResourcePtr &res) { | |||
| @@ -323,6 +328,8 @@ void InitOpt(const ResourcePtr &res) { | |||
| g_pass_opts["opt_grad_epilogue"] = | |||
| Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| g_pass_opts["opt_after_recompute"] = | |||
| Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| @@ -367,6 +374,7 @@ bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, | |||
| bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } | |||
| bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | |||
| bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } | |||
| bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); } | |||
| bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } | |||
| @@ -525,7 +533,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru | |||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | |||
| {"add_cache_embedding", AddCacheEmbeddingPass}, | |||
| {"add_control_depend", AddControlDependPass}, | |||
| {"add_recomputation", AddRecomputationPass}}; | |||
| {"add_recomputation", AddRecomputationPass}, | |||
| {"cse_after_recomputation", OptAfterRecomputeGroup}}; | |||
| std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| @@ -380,6 +380,8 @@ constexpr auto kAttrPadMode = "pad_mode"; | |||
| constexpr auto kAttrPad = "pad"; | |||
| constexpr auto kAttrPadding = "padding"; | |||
| constexpr auto kAttrIsGrad = "is_grad"; | |||
| constexpr auto kAttrRecompute = "recompute"; | |||
| constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||