/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "pipeline/pass.h" #include #include #include #include #include #include #include #include "ir/func_graph_cloner.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/data_converter.h" #include "pipeline/resource.h" #include "pipeline/validator.h" #include "optimizer/optimizer.h" #include "optimizer/cse.h" #include "optimizer/clean.h" #include "optimizer/irpass.h" #include "optimizer/control_depend.h" #include "parallel/step_parallel.h" #include "parallel/step_auto_parallel.h" #include "parallel/allreduce_fusion/step_allreduce_fusion.h" #include "utils/any.h" namespace mindspore { namespace pipeline { using OptPassGroupMap = opt::OptPassGroupMap; using Optimizer = opt::Optimizer; using CompileGraphs = compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; using mindspore::validator::Validate; bool SimplifyDataStructuresPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); opt::SimplifyDataStructures(func_graph, res->manager()); abstract::AbstractBasePtrList args_spec; auto parameters = func_graph->parameters(); (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); res->set_func_graph(new_fg); res->set_args_spec(args_spec); return true; } namespace { OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ irpass.switch_simplify_, // Safe inlining irpass.inline_, irpass.partial_eliminate_, irpass.replace_applicator_, // Specialization irpass.specialize_transform_, // Arithmetic simplifications irpass.arithmetic_simplify_, irpass.addn_zero_filter_, // Miscellaneous irpass.item_tuple_eliminate_, irpass.env_get_set_item_, irpass.new_env_get_item_, irpass.add_env_get_item_, irpass.cast_eliminate_, irpass.reshape_eliminate_, irpass.reduce_eliminate_, irpass.tile_eliminate_, irpass.transpose_eliminate_, irpass.minmaximum_grad_, irpass.get_make_ref_eliminate_, }); opt::OptPassConfig a_2 = opt::OptPassConfig({ irpass.merge_addn_, irpass.float_tuple_getitem_switch_, irpass.float_env_getitem_switch_, irpass.incorporate_getitem_, irpass.incorporate_getitem_switch_, irpass.incorporate_call_, irpass.incorporate_call_switch_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.same_eliminate_, irpass.check_bprop_eliminate_, irpass.replace_applicator_, }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); opt::irpass::ResolveIRPassLib resolve_irpass; opt::OptPassConfig resolve_pass = opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); OptPassGroupMap map_a({{"a_1", a_1}, {"a_2", a_2}, {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, {"parallel", opt::OptPassConfig(parallel::StepParallel)}, {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, {"virtual_dataset", virtual_dataset}, {"grad", grad}, {"resolve", resolve_pass}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"cse", opt::OptPassConfig(opt::CSE(false))}, {"a_3", a_3}}); return map_a; } OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig({ irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.stop_gradient_eliminate_, irpass.get_make_ref_eliminate_, }); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, }); OptPassGroupMap map({ {"b_1", b_1}, {"b_2", b_2}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"cse", opt::OptPassConfig(opt::CSE(false))}, }); return map; } OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}); OptPassGroupMap map({ {"control_group", control_group}, {"renormalize", opt::OptPassConfig::Renormalize()}, }); return map; } OptPassGroupMap GetInferenceOptPreparePhases() { opt::irpass::InferenceOptPrepareLib irpass; auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); return prepare_map; } OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); OptPassGroupMap map({{"prepare_group", prepare_group}}); return map; } static std::unordered_map> g_pass_opts = {}; void InitOpt(const ResourcePtr &res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); } } } // namespace void ReclaimOptimizer() { for (auto &opt : g_pass_opts) { opt.second = nullptr; } g_pass_opts.clear(); } bool OptPassGroup(const ResourcePtr &res, const std::string &name) { if (res->func_graph() == nullptr) { MS_LOG(ERROR) << "Opt passes int error"; return false; } FuncGraphPtr func_graph = res->func_graph(); MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", " << func_graph->get_return()->DebugString(true); InitOpt(res); if (g_pass_opts.find(name) != g_pass_opts.end()) { res->set_func_graph(g_pass_opts[name]->step(func_graph)); } // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here. return true; } bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } bool AddControlDependPass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { opt::AddControlDepend(func_graph); } for (auto fg : func_graph->func_graphs_used_total()) { MS_EXCEPTION_IF_NULL(fg); if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { opt::AddControlDepend(fg); } } return true; } bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr new_fg = LiftingClone(func_graph); res->set_func_graph(new_fg); return true; } bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); Validate(func_graph); return true; } bool InferenceOptPreparePass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); auto prepare_map = GetInferenceOptPreparePhases(); auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); (void)infer_opt_prepare->step(func_graph, false); return true; } std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, {"cconv", CconvPass}}; std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, {"cconv", CconvPass}}; } // namespace pipeline } // namespace mindspore