Merge pull request !5692 from BowenK/pre_adtags/v1.0.0
| @@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { | |||
| auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + | |||
| grad_op_child_scope_prefix + prim->name()); | |||
| ScopeGuard scope_guard(scope); | |||
| py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction(); | |||
| py::function fn; | |||
| if (prim->is_base()) { | |||
| fn = GetBpropFunction(prim->name()); | |||
| } else { | |||
| fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction(); | |||
| if (py::isinstance<py::none>(fn)) { | |||
| fn = GetBpropFunction(prim->name()); | |||
| } | |||
| } | |||
| if (!fn || py::isinstance<py::none>(fn)) { | |||
| MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; | |||
| return nullptr; | |||
| @@ -35,8 +35,10 @@ namespace internal { | |||
| const char PARAMETER_MODULE[] = "mindspore.common.parameter"; | |||
| const char PARAMETER_CLASS[] = "Parameter"; | |||
| const char SET_PARAM[] = "__setattr__"; | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph); | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res); | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph, | |||
| const FuncGraphPtr &top_graph); | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, | |||
| const MatchResultPtr &res); | |||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | |||
| bool requires_grad, bool layerwise_parallel); | |||
| @@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) | |||
| return std::make_shared<ValueNode>(input_tensor); | |||
| } | |||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) { | |||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg, | |||
| const FuncGraphPtr &top_graph) { | |||
| auto call_pattern = pattern->cast<CallPtr>(); | |||
| MS_EXCEPTION_IF_NULL(call_pattern); | |||
| auto prim = call_pattern->prim_value(); | |||
| @@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP | |||
| } | |||
| auto prim_pattern = call_pattern->prim_pattern(); | |||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||
| return ProcessSinglePattern(prim_pattern, res, fg); | |||
| return ProcessSinglePattern(prim_pattern, res, fg, top_graph); | |||
| } | |||
| AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { | |||
| AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) { | |||
| auto new_para_pattern = pattern->cast<NewParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_para_pattern); | |||
| if (!new_para_pattern->built()) { | |||
| static int parameter_id = 0; | |||
| auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++); | |||
| auto para_node = std::make_shared<Parameter>(func_graph); | |||
| auto para_node = std::make_shared<Parameter>(top_graph); | |||
| MS_EXCEPTION_IF_NULL(para_node); | |||
| para_node->set_name(para_name); | |||
| // Set function graph | |||
| para_node->set_func_graph(func_graph); | |||
| para_node->set_func_graph(top_graph); | |||
| // Set Debug Info | |||
| auto debug_info = std::make_shared<NodeDebugInfo>(para_name); | |||
| para_node->set_debug_info(debug_info); | |||
| @@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re | |||
| MS_EXCEPTION_IF_NULL(default_value); | |||
| para_node->set_abstract(default_value->ToAbstract()->Broaden()); | |||
| res->add_entry(pattern, para_node); | |||
| func_graph->add_parameter(para_node); | |||
| top_graph->add_parameter(para_node); | |||
| // Reflect back to Cell._params | |||
| internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | |||
| new_para_pattern->layerwise_parallel()); | |||
| @@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| return std::make_shared<ValueNode>(scalar_value_ptr); | |||
| } | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph, | |||
| const FuncGraphPtr &top_graph) { | |||
| auto target_node = res->get_node(pattern); | |||
| if (target_node != nullptr) { | |||
| // If pattern is NewParameter, check whether it shouldn't last and is not built | |||
| @@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr | |||
| } else if (pattern->isa<NewTensor>()) { | |||
| return BuildNewTensor(pattern, res); | |||
| } else if (pattern->isa<Call>()) { | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph); | |||
| } else if (pattern->isa<NewParameter>()) { | |||
| return BuildNewParameter(pattern, res, func_graph); | |||
| // Add new parameter to top graph instead of current graph | |||
| return BuildNewParameter(pattern, res, top_graph); | |||
| } else if (pattern->isa<Imm>()) { | |||
| return BuildImmNode(pattern, res); | |||
| } else { | |||
| @@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr | |||
| } | |||
| AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res, | |||
| const FuncGraphPtr &func_graph) { | |||
| const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) { | |||
| if (pattern->isa<Call>()) { | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph); | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, | |||
| const MatchResultPtr &res) { | |||
| auto target_inputs = pattern->inputs(); | |||
| if (target_inputs.size() == 0) { | |||
| auto new_node = ProcessSinglePattern(pattern, res, func_graph); | |||
| auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph); | |||
| if (new_node != nullptr) { | |||
| res->add_entry(pattern, new_node); | |||
| } | |||
| @@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||
| } | |||
| // Build up the AnfNode in a recursive manner | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph); | |||
| auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph); | |||
| MS_EXCEPTION_IF_NULL(prim_value_node); | |||
| new_inputs.push_back(prim_value_node); | |||
| for (auto &iter : target_inputs) { | |||
| if (iter == pattern) { | |||
| MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n"; | |||
| } | |||
| auto input_node = BuildTarget(iter, func_graph, res); | |||
| auto input_node = BuildTarget(iter, func_graph, top_graph, res); | |||
| if (input_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n"; | |||
| } | |||
| @@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) { | |||
| } // namespace internal | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node, | |||
| const MatchResultPtr &res) { | |||
| auto match_res = src_pattern_->match(node); | |||
| if (match_res != nullptr) { | |||
| res->merge(match_res); | |||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | |||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res); | |||
| internal::Reset(dst_pattern()); | |||
| return new_node; | |||
| } | |||
| @@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| auto graph_nodes_sorted = TopoSort(func_graph->output()); | |||
| auto func_graphs = manager->func_graphs(); | |||
| bool changes = false; | |||
| // Traverse once | |||
| for (auto &node : graph_nodes_sorted) { | |||
| AnfNodePtr new_node = Run(func_graph, node, res); | |||
| if (new_node != nullptr && new_node != node) { | |||
| (void)manager->Replace(node, new_node); | |||
| changes = true; | |||
| for (auto &fg : func_graphs) { | |||
| manager->AddFuncGraph(fg); | |||
| auto graph_nodes_sorted = TopoSort(fg->output()); | |||
| // Traverse once | |||
| for (auto &node : graph_nodes_sorted) { | |||
| AnfNodePtr new_node = Run(fg, func_graph, node, res); | |||
| if (new_node != nullptr && new_node != node) { | |||
| MS_LOG(WARNING) << "Matched"; | |||
| (void)manager->Replace(node, new_node); | |||
| changes = true; | |||
| } | |||
| } | |||
| } | |||
| return changes; | |||
| @@ -39,7 +39,8 @@ class PythonPass { | |||
| ~PythonPass() = default; | |||
| bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res); | |||
| std::string name() const { return name_; } | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res); | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node, | |||
| const MatchResultPtr &res); | |||
| PatternPtr src_pattern() { return src_pattern_; } | |||
| PatternPtr dst_pattern() { return dst_pattern_; } | |||
| @@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() { | |||
| } | |||
| PyPassManager::PyPassManager() { | |||
| phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>(); | |||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | |||
| phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup"); | |||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup"); | |||
| res_ = std::make_shared<MatchResult>(); | |||
| } | |||
| void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| bool run_only_once) { | |||
| // NOTE: remove phase option to avoid unnecessary confusion. | |||
| auto cur_pg = GetPassGroup(Phase::OPT); | |||
| bool requires_grad, bool run_only_once) { | |||
| PassGroupPtr cur_pg; | |||
| if (requires_grad) { | |||
| cur_pg = GetPassGroup(Phase::PREAD); | |||
| } else { | |||
| cur_pg = GetPassGroup(Phase::OPT); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cur_pg); | |||
| cur_pg->SetRunOnlyOnce(run_only_once); | |||
| MS_EXCEPTION_IF_NULL(pattern); | |||
| @@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt | |||
| } | |||
| void PyPassManager::Unregiste(const std::string &pass_name) { | |||
| // NOTE: remove phase option to avoid unnecessary confusion. | |||
| auto cur_pm = GetPassGroup(Phase::OPT); | |||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||
| if (!cur_pm->DeletePass(pass_name)) { | |||
| MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; | |||
| auto opt_pm = GetPassGroup(Phase::OPT); | |||
| if (!opt_pm->DeletePass(pass_name)) { | |||
| MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n"; | |||
| } | |||
| auto pre_ad_pm = GetPassGroup(Phase::PREAD); | |||
| if (!pre_ad_pm->DeletePass(pass_name)) { | |||
| MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n"; | |||
| } | |||
| } | |||
| @@ -92,7 +98,7 @@ void PyPassManager::ClearRes() { | |||
| REGISTER_PYBIND_DEFINE( | |||
| PyPassManager_, ([](const py::module *m) { | |||
| (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); | |||
| (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT); | |||
| (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_") | |||
| .def(py::init([]() { return PyPassManager::GetInstance(); })) | |||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | |||
| @@ -38,7 +38,7 @@ namespace python_pass { | |||
| class PyPassManager; | |||
| using PyPassManagerPtr = std::shared_ptr<PyPassManager>; | |||
| enum Phase { RESOLVE, OPT }; | |||
| enum Phase { PREAD, OPT }; | |||
| class PyPassManager { | |||
| protected: | |||
| @@ -52,8 +52,8 @@ class PyPassManager { | |||
| // Access the only global instance | |||
| static PyPassManagerPtr GetInstance(); | |||
| virtual ~PyPassManager() = default; | |||
| void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| bool run_only_once = false); | |||
| void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad, | |||
| bool run_only_once); | |||
| void Unregiste(const std::string &pass_name); | |||
| void GenNewParameter(const PatternPtr ¶meter); | |||
| PassGroupPtr GetPassGroup(Phase phase); | |||
| @@ -301,6 +301,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) | |||
| return true; | |||
| } | |||
| bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); } | |||
| bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } | |||
| bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } | |||
| @@ -473,7 +475,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||
| return ppm->GetPassGroup(phase)->Run(res->func_graph()); | |||
| } | |||
| bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); } | |||
| bool PreAdActionPyStub(const ResourcePtr &res) { | |||
| if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) { | |||
| MS_LOG(DEBUG) << "No Match."; | |||
| } | |||
| return true; | |||
| } | |||
| bool OptActionVmPyStub(const ResourcePtr &res) { | |||
| if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { | |||
| @@ -529,12 +536,14 @@ static std::vector<ActionItem> CommonPipeline() { | |||
| if (!multi_graphs) { | |||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | |||
| } | |||
| // Add resolve-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); | |||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | |||
| // Evaluate type and shape, and specialize | |||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | |||
| // Do data structure simplifications and inline | |||
| actions.emplace_back(std::make_pair("inline", OptInlineAction)); | |||
| // Add pre-ad, post-inline python pass stub | |||
| actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub)); | |||
| return actions; | |||
| } | |||
| @@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return map_a; | |||
| } | |||
| OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| auto opt_a = GetOptPassesA(irpass); | |||
| OptPassGroupMap a1_a2({opt_a[0], opt_a[1]}); | |||
| return a1_a2; | |||
| } | |||
| OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig c_1 = opt::OptPassConfig({ | |||
| // Safe inlining, | |||
| @@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = | |||
| void InitOpt(const ResourcePtr &res) { | |||
| if (g_pass_opts.size() == 0) { | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass)); | |||
| g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); | |||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | |||
| g_pass_opts["opt_after_cconv"] = | |||
| @@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { | |||
| return true; | |||
| } | |||
| bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); } | |||
| bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | |||
| bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | |||
| bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } | |||
| @@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, | |||
| {"cconv", CconvPass}, | |||
| {"transform_top", TransformTopGraphPass}, | |||
| {"transform_graph", OptPassTransformGraphGroup}}; | |||
| std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}}; | |||
| } // namespace pipeline | |||
| } // namespace mindspore | |||
| @@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>; | |||
| extern std::vector<PassItem> kGePasses; | |||
| extern std::vector<PassItem> kVmPasses; | |||
| extern std::vector<PassItem> kInlinePasses; | |||
| extern std::vector<PassItem> kPynativePasses; | |||
| bool CconvPass(const ResourcePtr &res); | |||
| @@ -13,14 +13,14 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Reference for python pass registration.""" | |||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ | |||
| set_reopt | |||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\ | |||
| _set_reopt | |||
| __all__ = [ | |||
| "registe_pass", | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm", | |||
| "set_reopt" | |||
| "_set_renorm", | |||
| "_set_reopt" | |||
| ] | |||
| @@ -23,22 +23,26 @@ __all__ = [ | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm", | |||
| "set_reopt" | |||
| "_set_renorm", | |||
| "_set_reopt" | |||
| ] | |||
| class PyPassManager(PyPassManager_): | |||
| r""" | |||
| Used to registe and unregiste python passes which can be used to alter graphs. | |||
| Args: | |||
| requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True | |||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| """ | |||
| def __init__(self, run_only_once=False): | |||
| def __init__(self, requires_grad=True, run_only_once=False): | |||
| if not isinstance(requires_grad, bool): | |||
| raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") | |||
| if not isinstance(run_only_once, bool): | |||
| raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}") | |||
| self.requires_grad = requires_grad | |||
| self.run_only_once_ = run_only_once | |||
| PyPassManager_.__init__(self) | |||
| @@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_): | |||
| raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") | |||
| if not isinstance(target, Pattern): | |||
| raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") | |||
| super().registe(pass_name, pattern, target, self.run_only_once_) | |||
| super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_) | |||
| def unregiste(self, py_pass): | |||
| if isinstance(py_pass, str): | |||
| @@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_): | |||
| raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") | |||
| super().set_reopt(do_reopt) | |||
| def registe_pass(run_only_once=False): | |||
| def registe_pass(requires_grad=True, run_only_once=False): | |||
| """ | |||
| Registe python pass to specified pipeline phase which would be used in compilation. | |||
| Args: | |||
| requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True | |||
| run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False. | |||
| Returns: | |||
| @@ -99,7 +104,7 @@ def registe_pass(run_only_once=False): | |||
| >>> target = IsPrimTypeOf("ReLU6") | |||
| >>> return pattern, target | |||
| """ | |||
| return PyPassManager(run_only_once) | |||
| return PyPassManager(requires_grad, run_only_once) | |||
| def unregiste_pass(py_pass): | |||
| """ | |||
| @@ -157,7 +162,7 @@ def cancel_new_parameter(pattern): | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(pattern.para_name) | |||
| def set_renorm(should_renorm): | |||
| def _set_renorm(should_renorm): | |||
| """ | |||
| Set whether or not to do renormalization after modified graph in python pass(es). | |||
| @@ -171,7 +176,7 @@ def set_renorm(should_renorm): | |||
| ppm = PyPassManager() | |||
| ppm.set_renorm(should_renorm) | |||
| def set_reopt(do_reopt): | |||
| def _set_reopt(do_reopt): | |||
| """ | |||
| Set whether or not to do optimization after modified graph in python pass(es). | |||
| @@ -19,8 +19,8 @@ import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ | |||
| cancel_new_parameter, set_reopt | |||
| from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\ | |||
| cancel_new_parameter, _set_reopt | |||
| from mindspore.common.api import _generate_pip_args | |||
| from mindspore._c_expression import generate_key, Executor_ | |||
| from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm | |||
| @@ -157,8 +157,8 @@ def test_isnot_pattern_0(): | |||
| Test IsNot pattern which expresses the IsNot semantics. | |||
| Case: IsNot pass failed to match | |||
| """ | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| _set_renorm(False) | |||
| _set_reopt(False) | |||
| class ConvBN(nn.Cell): | |||
| def __init__(self): | |||
| super(ConvBN, self).__init__() | |||
| @@ -176,7 +176,7 @@ def test_isnot_pattern_0(): | |||
| inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) | |||
| conv_bn_model = ConvBN() | |||
| @registe_pass(run_only_once=True) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def single_bn_pass(): | |||
| """ | |||
| Sub a BN which does NOT take Conv as inputs to ReLU6. | |||
| @@ -188,7 +188,7 @@ def test_isnot_pattern_0(): | |||
| target = Call(P.ReLU6(), [pattern_0]) | |||
| return pattern, target | |||
| @registe_pass(run_only_once=True) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def bn_pass(): | |||
| """ | |||
| Sub a BN to Softmax. | |||
| @@ -202,7 +202,7 @@ def test_isnot_pattern_0(): | |||
| unregiste_pass(bn_pass) | |||
| assert "ReLU6" not in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| set_renorm(True) | |||
| _set_renorm(True) | |||
| def test_isnot_pattern_1(): | |||
| """ | |||
| @@ -234,12 +234,12 @@ def test_newtensor_pattern(): | |||
| """ | |||
| Test NewTensor pattern in the target | |||
| """ | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| _set_renorm(False) | |||
| _set_reopt(False) | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), [x]) | |||
| @@ -252,7 +252,7 @@ def test_newtensor_pattern(): | |||
| unregiste_pass(softmax_addn_pass) | |||
| assert "AddN" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| set_renorm(True) | |||
| _set_renorm(True) | |||
| def test_newparameter_pattern(): | |||
| """ | |||
| @@ -261,9 +261,9 @@ def test_newparameter_pattern(): | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| @registe_pass(run_only_once=True) | |||
| _set_renorm(False) | |||
| _set_reopt(False) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), [x]) | |||
| @@ -288,9 +288,9 @@ def test_imm_target(): | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| @registe_pass(run_only_once=True) | |||
| _set_renorm(False) | |||
| _set_reopt(False) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def softmax_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), [x]) | |||
| @@ -313,10 +313,10 @@ def test_gen_new_parameter(): | |||
| default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para = NewParameter("Merlin", default_tensor) | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| _set_renorm(False) | |||
| _set_reopt(False) | |||
| gen_new_parameter(new_para) | |||
| @registe_pass(run_only_once=True) | |||
| @registe_pass(requires_grad=False, run_only_once=True) | |||
| def softmax_make_tuple_pass(): | |||
| x = Any() | |||
| softmax = P.Softmax() | |||