| @@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, | |||
| irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); | |||
| // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases(). | |||
| OptPassGroupMap map_a({{"a_1", a_1}, | |||
| {"a_2", a_2}, | |||
| {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, | |||
| @@ -272,6 +273,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| auto opt_a = GetOptPassesA(irpass); | |||
| auto a3 = opt_a[opt_a.size() - 1]; | |||
| OptPassGroupMap map({ | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||
| {a3}, | |||
| }); | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetInferenceOptPreparePhases() { | |||
| opt::irpass::InferenceOptPrepareLib irpass; | |||
| auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); | |||
| @@ -303,6 +315,8 @@ void InitOpt(const ResourcePtr &res) { | |||
| Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); | |||
| g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); | |||
| g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); | |||
| g_pass_opts["opt_grad_epilogue"] = | |||
| Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -351,6 +365,8 @@ bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepar | |||
| bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } | |||
| bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); } | |||
| bool AddControlDependPass(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -469,7 +485,8 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru | |||
| {"opt_prepare", PrepareGroup}, | |||
| {"cconv", CconvPass}}; | |||
| std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, | |||
| std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup}, | |||
| {"opt_a", OptPassAGroup}, | |||
| {"opt_b", OptPassBGroup}, | |||
| {"cconv", CconvPass}, | |||
| {"transform_top", TransformTopGraphPass}, | |||