From: @ginfung
Tag: tags/v1.2.0-rc1
@@ -24,6 +24,7 @@
 #include "abstract/abstract_function.h"
 #include "utils/flags.h"
+#include "utils/utils.h"
 
 namespace mindspore {
 /* namespace to support opt */
@@ -32,6 +33,20 @@ using mindspore::abstract::AbstractBase;
 using mindspore::abstract::AbstractFunction;
 using mindspore::abstract::AbstractFunctionPtr;
+
+bool WithRecomputedScope(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto full_name_with_scope = node->fullname_with_scope();
+  return full_name_with_scope.find(kAttrRecompute) == 0;
+}
+
+bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
+  return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) ||
+         (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
+}
 BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
   MS_EXCEPTION_IF_NULL(node);
   auto node_abs = node->abstract();
@@ -83,7 +98,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
   } else if (node->isa<Parameter>()) {
     h = node->hash();
   } else {
-    MS_LOG(ERROR) << "Unknow node type";
+    MS_LOG(ERROR) << "Unknown node type";
   }
   hashes[node] = h;
@@ -142,6 +157,10 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
   } else if (main->isa<CNode>() && node->isa<CNode>()) {
     auto c_main = main->cast<CNodePtr>();
     auto c_node = node->cast<CNodePtr>();
+    // Do not do CSE for nodes set to recompute before the recompute pass has run.
+    if (IsSetRecomputed(c_main, c_node)) {
+      return false;
+    }
     // When appsame is true, check if has side effect, do not merge.
     if (check_side_effect && HasSideEffect(main)) {
       return false;
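
Note: the intent of the CSE hunks above is that any node whose full scope name begins with the "recompute" prefix is excluded from common subexpression elimination until it carries the need_cse_after_recompute attribute. Below is a minimal standalone sketch of that gating predicate; FakeNode and its members are hypothetical stand-ins for illustration only, not MindSpore's AnfNode/CNode API.

#include <iostream>
#include <set>
#include <string>

// Illustrative stand-ins; the real pass operates on MindSpore AnfNode/CNode objects.
constexpr const char *kAttrRecompute = "recompute";
constexpr const char *kAttrNeedCseAfterRecompute = "need_cse_after_recompute";

struct FakeNode {
  std::string fullname_with_scope;  // e.g. "recompute_Default/Conv2D-op1"
  std::set<std::string> attrs;      // names of attributes set on the node
  bool HasAttr(const std::string &name) const { return attrs.count(name) != 0; }
};

// Mirrors WithRecomputedScope: the scoped name starts with the "recompute" prefix.
bool WithRecomputedScope(const FakeNode &node) {
  return node.fullname_with_scope.find(kAttrRecompute) == 0;
}

// Mirrors IsSetRecomputed: CSE must skip the pair if either node is in a
// recompute scope and has not yet been marked as mergeable after recompute.
bool IsSetRecomputed(const FakeNode &a, const FakeNode &b) {
  return (WithRecomputedScope(a) && !a.HasAttr(kAttrNeedCseAfterRecompute)) ||
         (WithRecomputedScope(b) && !b.HasAttr(kAttrNeedCseAfterRecompute));
}

int main() {
  FakeNode recomputed{"recompute_Default/Conv2D-op1", {}};
  FakeNode plain{"Default/Conv2D-op2", {}};
  std::cout << std::boolalpha << IsSetRecomputed(recomputed, plain) << "\n";  // true -> skip CSE
  recomputed.attrs.insert(kAttrNeedCseAfterRecompute);
  std::cout << std::boolalpha << IsSetRecomputed(recomputed, plain) << "\n";  // false -> CSE allowed
  return 0;
}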
@@ -25,12 +25,12 @@
 #include <algorithm>
 #include "ir/func_graph.h"
 #include "mindspore/core/base/core_ops.h"
+#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
 namespace {
 constexpr auto kGradientsFlag = "Gradients";
-constexpr auto kAttrRecompute = "recompute";
 bool IsBpropNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
@@ -339,6 +339,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
   auto recomputed_node = graph->NewCNode(new_inputs);
   MS_EXCEPTION_IF_NULL(recomputed_node);
   recomputed_node->AddAttr("duplicated", MakeValue(true));
+  recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
   recomputed_node->set_abstract(origin_node->abstract());
   recomputed_node->set_scope(origin_node->scope());
   origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
@@ -415,6 +416,12 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) {
     DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
                              &origin_to_recomputed_nodes);
   }
+  // Set the need-cse attr so that CSE can be done after the recompute pass.
+  for (const auto &node : orders) {
+    if (WithRecomputedScope(node)) {
+      node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
+    }
+  }
 }
 } // namespace opt
 } // namespace mindspore
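
Note: the recompute-pass hunks above are the producer side of that attribute: NewRecomputedNode tags each duplicated node, and the loop appended to InsertRecomputedNodes tags the original recompute-scope nodes, so the follow-up CSE pass may merge the redundant copies. A rough sketch of that tagging step, again with a hypothetical FakeNode stand-in rather than the real FuncGraph/CNode API:

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

constexpr const char *kAttrRecompute = "recompute";
constexpr const char *kAttrNeedCseAfterRecompute = "need_cse_after_recompute";

// Hypothetical stand-in for a graph node; the real pass walks CNodes in a FuncGraph.
struct FakeNode {
  std::string fullname_with_scope;
  std::set<std::string> attrs;
  void AddAttr(const std::string &name) { attrs.insert(name); }
};
using FakeNodePtr = std::shared_ptr<FakeNode>;

// Mirrors the loop appended to InsertRecomputedNodes: once duplication is done,
// every node whose scope starts with "recompute" is marked so the follow-up
// CSE pass is allowed to merge it.
void MarkNodesForCseAfterRecompute(const std::vector<FakeNodePtr> &orders) {
  for (const auto &node : orders) {
    if (node->fullname_with_scope.find(kAttrRecompute) == 0) {
      node->AddAttr(kAttrNeedCseAfterRecompute);
    }
  }
}

int main() {
  std::vector<FakeNodePtr> orders = {
      std::make_shared<FakeNode>(FakeNode{"recompute_Default/MatMul-op3", {}}),
      std::make_shared<FakeNode>(FakeNode{"Default/Add-op4", {}}),
  };
  MarkNodesForCseAfterRecompute(orders);
  for (const auto &node : orders) {
    std::cout << node->fullname_with_scope << " tagged=" << std::boolalpha
              << (node->attrs.count(kAttrNeedCseAfterRecompute) != 0) << "\n";
  }
  return 0;
}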
@@ -302,6 +302,11 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
   return map;
 }
 
+OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
+  OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
+  return map;
+}
+
 static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
 
 void InitOpt(const ResourcePtr &res) {
@@ -323,6 +328,8 @@ void InitOpt(const ResourcePtr &res) {
   g_pass_opts["opt_grad_epilogue"] =
     Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
   g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
+  g_pass_opts["opt_after_recompute"] =
+    Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
@@ -367,6 +374,7 @@ bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res,
 bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
 bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
 bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
+bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
 bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
@@ -525,7 +533,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                    {"add_cache_embedding", AddCacheEmbeddingPass},
                                    {"add_control_depend", AddControlDependPass},
-                                   {"add_recomputation", AddRecomputationPass}};
+                                   {"add_recomputation", AddRecomputationPass},
+                                   {"cse_after_recomputation", OptAfterRecomputeGroup}};
 
 std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_a", OptPassAGroup},
@@ -380,6 +380,8 @@ constexpr auto kAttrPadMode = "pad_mode";
 constexpr auto kAttrPad = "pad";
 constexpr auto kAttrPadding = "padding";
 constexpr auto kAttrIsGrad = "is_grad";
+constexpr auto kAttrRecompute = "recompute";
+constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";