|
|
@@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { |
|
|
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, |
|
|
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, |
|
|
irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); |
|
|
irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); |
|
|
|
|
|
|
|
|
|
|
|
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases(). |
|
|
OptPassGroupMap map_a({{"a_1", a_1}, |
|
|
OptPassGroupMap map_a({{"a_1", a_1}, |
|
|
{"a_2", a_2}, |
|
|
{"a_2", a_2}, |
|
|
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, |
|
|
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, |
|
|
@@ -272,6 +273,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { |
|
|
return map; |
|
|
return map; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) { |
|
|
|
|
|
auto opt_a = GetOptPassesA(irpass); |
|
|
|
|
|
auto a3 = opt_a[opt_a.size() - 1]; |
|
|
|
|
|
OptPassGroupMap map({ |
|
|
|
|
|
{"renormalize", opt::OptPassConfig::Renormalize()}, |
|
|
|
|
|
{"cse", opt::OptPassConfig(opt::CSEPass(false))}, |
|
|
|
|
|
{a3}, |
|
|
|
|
|
}); |
|
|
|
|
|
return map; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
OptPassGroupMap GetInferenceOptPreparePhases() { |
|
|
OptPassGroupMap GetInferenceOptPreparePhases() { |
|
|
opt::irpass::InferenceOptPrepareLib irpass; |
|
|
opt::irpass::InferenceOptPrepareLib irpass; |
|
|
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); |
|
|
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); |
|
|
@@ -303,6 +315,8 @@ void InitOpt(const ResourcePtr &res) { |
|
|
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); |
|
|
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); |
|
|
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); |
|
|
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); |
|
|
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); |
|
|
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); |
|
|
|
|
|
g_pass_opts["opt_grad_epilogue"] = |
|
|
|
|
|
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); |
|
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); |
|
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
@@ -351,6 +365,8 @@ bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepar |
|
|
|
|
|
|
|
|
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } |
|
|
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } |
|
|
|
|
|
|
|
|
|
|
|
bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); } |
|
|
|
|
|
|
|
|
bool AddControlDependPass(const ResourcePtr &res) { |
|
|
bool AddControlDependPass(const ResourcePtr &res) { |
|
|
FuncGraphPtr func_graph = res->func_graph(); |
|
|
FuncGraphPtr func_graph = res->func_graph(); |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
@@ -469,7 +485,8 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru |
|
|
{"opt_prepare", PrepareGroup}, |
|
|
{"opt_prepare", PrepareGroup}, |
|
|
{"cconv", CconvPass}}; |
|
|
{"cconv", CconvPass}}; |
|
|
|
|
|
|
|
|
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, |
|
|
|
|
|
|
|
|
std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup}, |
|
|
|
|
|
{"opt_a", OptPassAGroup}, |
|
|
{"opt_b", OptPassBGroup}, |
|
|
{"opt_b", OptPassBGroup}, |
|
|
{"cconv", CconvPass}, |
|
|
{"cconv", CconvPass}, |
|
|
{"transform_top", TransformTopGraphPass}, |
|
|
{"transform_top", TransformTopGraphPass}, |
|
|
|