| @@ -59,6 +59,7 @@ class Pattern : public Base { | |||||
| string unique_name() const { return unique_name_; } | string unique_name() const { return unique_name_; } | ||||
| vector<PatternPtr> inputs() { return inputs_; } | vector<PatternPtr> inputs() { return inputs_; } | ||||
| virtual void reset() {} | virtual void reset() {} | ||||
| static void reset_gid() { g_id_ = 0; } | |||||
| protected: | protected: | ||||
| static int g_id_; | static int g_id_; | ||||
| @@ -213,7 +214,6 @@ class NewParameter : public Pattern { | |||||
| explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) | explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) | ||||
| : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { | : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { | ||||
| unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; | unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; | ||||
| // clone input tensor | |||||
| default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); | default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); | ||||
| built_ = false; | built_ = false; | ||||
| } | } | ||||
| @@ -257,7 +257,7 @@ class MatchResult { | |||||
| MatchResult() {} | MatchResult() {} | ||||
| ~MatchResult() = default; | ~MatchResult() = default; | ||||
| void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } | void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } | ||||
| PatternNodeMap _result() { return match_result_; } | |||||
| PatternNodeMap &_result() { return match_result_; } | |||||
| AnfNodePtr get_node(const PatternPtr &pattern); | AnfNodePtr get_node(const PatternPtr &pattern); | ||||
| void merge(const MatchResultPtr &other_result); | void merge(const MatchResultPtr &other_result); | ||||
| void clear() { match_result_.clear(); } | void clear() { match_result_.clear(); } | ||||
| @@ -27,8 +27,6 @@ | |||||
| #include "pipeline/jit/resource.h" | #include "pipeline/jit/resource.h" | ||||
| #include "frontend/optimizer/py_pass_manager.h" | #include "frontend/optimizer/py_pass_manager.h" | ||||
| #include "utils/info.h" | #include "utils/info.h" | ||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "debug/draw.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | ||||
| bool requires_grad, bool layerwise_parallel); | bool requires_grad, bool layerwise_parallel); | ||||
| std::string GetNodeRepr(AnfNodePtr node) { | |||||
| if (node != nullptr) { | |||||
| if (node->isa<CNode>()) { | |||||
| std::string repr = "("; | |||||
| auto const &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| for (auto &input : inputs) { | |||||
| repr += " "; | |||||
| repr += GetNodeRepr(input); | |||||
| repr += " "; | |||||
| } | |||||
| repr += ")"; | |||||
| return repr; | |||||
| } | |||||
| if (node->isa<Parameter>()) { | |||||
| return "[Parameter]" + node->ToString(); | |||||
| } else if (node->isa<ValueNode>()) { | |||||
| return "[Value]" + GetValueNode(node)->ToString(); | |||||
| } | |||||
| return node->ToString(); | |||||
| } | |||||
| return ""; | |||||
| } | |||||
| bool IsTraversable(const AnfNodePtr &node) { | bool IsTraversable(const AnfNodePtr &node) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return false; | return false; | ||||
| @@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| void DrawNode(string name, AnfNodePtr node) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| auto new_func_graph = std::make_shared<FuncGraph>(); | |||||
| new_func_graph->set_output(node, true); | |||||
| if (save_graphs) { | |||||
| auto ir_dump_path = save_graphs_path + "/" + name + ".ir"; | |||||
| auto dot_dump_path = save_graphs_path + "/" + name + ".dot"; | |||||
| DumpIR(ir_dump_path, new_func_graph); | |||||
| draw::Draw(dot_dump_path, new_func_graph); | |||||
| } | |||||
| } | |||||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | ||||
| bool requires_grad, bool layerwise_parallel) { | bool requires_grad, bool layerwise_parallel) { | ||||
| // 1. Get current cell object | // 1. Get current cell object | ||||
| @@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor | |||||
| if (py::isinstance<py::none>(top_cell)) { | if (py::isinstance<py::none>(top_cell)) { | ||||
| MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; | MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; | ||||
| } | } | ||||
| // 2. New a Parameter object with the above-specified args | |||||
| // 2. Clone default_input tensor | |||||
| auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(), | |||||
| default_input->data_c(), (size_t)default_input->Size()); | |||||
| // 3. New a Parameter object with the above-specified args | |||||
| py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); | py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); | ||||
| py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel); | |||||
| // 3. Add the new python Parameter object to Cell's _params atttributes | |||||
| py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel); | |||||
| // 4. Add the new python Parameter object to Cell's _params atttributes | |||||
| top_cell.attr(SET_PARAM)(param_name, new_parameter); | top_cell.attr(SET_PARAM)(param_name, new_parameter); | ||||
| // 4. Set default_param for param_node | |||||
| // 5. Set default_param for param_node | |||||
| ValuePtr param_value = nullptr; | ValuePtr param_value = nullptr; | ||||
| bool converted = parse::ConvertData(new_parameter, ¶m_value, false); | bool converted = parse::ConvertData(new_parameter, ¶m_value, false); | ||||
| if (!converted) { | if (!converted) { | ||||
| @@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) { | |||||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { | AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { | ||||
| auto match_res = src_pattern_->match(node); | auto match_res = src_pattern_->match(node); | ||||
| if (match_res != nullptr) { | if (match_res != nullptr) { | ||||
| MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node); | |||||
| res->merge(match_res); | res->merge(match_res); | ||||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | ||||
| internal::Reset(dst_pattern()); | internal::Reset(dst_pattern()); | ||||
| MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| internal::Reset(src_pattern()); | internal::Reset(src_pattern()); | ||||
| @@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||||
| MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; | MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; | ||||
| } | } | ||||
| auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); | auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); | ||||
| MS_LOG(DEBUG) << "Adding New parameter : " + para_name; | |||||
| auto para_node = std::make_shared<Parameter>(func_graph); | auto para_node = std::make_shared<Parameter>(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(para_node); | MS_EXCEPTION_IF_NULL(para_node); | ||||
| para_node->set_name(para_name); | para_node->set_name(para_name); | ||||
| @@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||||
| // Reflect back to Cell._params | // Reflect back to Cell._params | ||||
| internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | ||||
| new_para_pattern->layerwise_parallel()); | new_para_pattern->layerwise_parallel()); | ||||
| MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||||
| MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| FuncGraphManagerPtr manager = func_graph->manager(); | FuncGraphManagerPtr manager = func_graph->manager(); | ||||
| @@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||||
| for (auto &node : graph_nodes_sorted) { | for (auto &node : graph_nodes_sorted) { | ||||
| AnfNodePtr new_node = Run(func_graph, node, res); | AnfNodePtr new_node = Run(func_graph, node, res); | ||||
| if (new_node != nullptr && new_node != node) { | if (new_node != nullptr && new_node != node) { | ||||
| internal::DrawNode(dst_pattern_->unique_name(), new_node); | |||||
| (void)manager->Replace(node, new_node); | (void)manager->Replace(node, new_node); | ||||
| changes = true; | changes = true; | ||||
| } | } | ||||
| @@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE( | |||||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | .def("registe", &PyPassManager::Registe, "Registe python pass") | ||||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") | .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") | ||||
| .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") | .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") | ||||
| .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph"); | |||||
| .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph") | |||||
| .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph"); | |||||
| })); | })); | ||||
| } // namespace python_pass | } // namespace python_pass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -60,13 +60,19 @@ class PyPassManager { | |||||
| MatchResultPtr GetMatchResult() { return res_; } | MatchResultPtr GetMatchResult() { return res_; } | ||||
| void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } | void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } | ||||
| bool ShouldRenorm() { return should_renorm_; } | bool ShouldRenorm() { return should_renorm_; } | ||||
| void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; } | |||||
| bool ShouldReOpt() { return should_reopt_; } | |||||
| void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } | void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } | ||||
| pipeline::ResourcePtr GetResource() { return resource_; } | pipeline::ResourcePtr GetResource() { return resource_; } | ||||
| void ClearRes(); | void ClearRes(); | ||||
| void ClearPipelineRes() { resource_ = nullptr; } | |||||
| void ClearPipelineRes() { | |||||
| resource_ = nullptr; | |||||
| Pattern::reset_gid(); | |||||
| } | |||||
| private: | private: | ||||
| bool should_renorm_ = true; | bool should_renorm_ = true; | ||||
| bool should_reopt_ = true; | |||||
| MatchResultPtr res_; | MatchResultPtr res_; | ||||
| pipeline::ResourcePtr resource_; | pipeline::ResourcePtr resource_; | ||||
| static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | ||||
| @@ -451,35 +451,55 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||||
| bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | ||||
| void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||||
| bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||||
| MS_EXCEPTION_IF_NULL(res->manager()); | MS_EXCEPTION_IF_NULL(res->manager()); | ||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | MS_EXCEPTION_IF_NULL(res->func_graph()); | ||||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | auto ppm = opt::python_pass::PyPassManager::GetInstance(); | ||||
| ppm->SetResource(res); | ppm->SetResource(res); | ||||
| if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { | |||||
| MS_LOG(DEBUG) << "No match.\n"; | |||||
| } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||||
| MS_LOG(DEBUG) << "Entered PyStub Renorm"; | |||||
| // Renomalize | |||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||||
| FuncGraphPtr func_graph = res->func_graph(); | |||||
| abstract::AbstractBasePtrList args_spec; | |||||
| auto parameters = func_graph->parameters(); | |||||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||||
| res->set_func_graph(new_fg); | |||||
| res->set_args_spec(args_spec); | |||||
| } | |||||
| return ppm->GetPassGroup(phase)->Run(res->func_graph()); | |||||
| } | } | ||||
| bool ResolveActionPyStub(const ResourcePtr &res) { | |||||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||||
| bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); } | |||||
| bool OptActionVmPyStub(const ResourcePtr &res) { | |||||
| if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { | |||||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||||
| // Renomalize | |||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||||
| FuncGraphPtr func_graph = res->func_graph(); | |||||
| abstract::AbstractBasePtrList args_spec; | |||||
| auto parameters = func_graph->parameters(); | |||||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||||
| res->set_func_graph(new_fg); | |||||
| res->set_args_spec(args_spec); | |||||
| } | |||||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { | |||||
| return VmOptimizeAction(res); | |||||
| } | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool OptActionPyStub(const ResourcePtr &res) { | |||||
| ActionPyStub(res, opt::python_pass::Phase::OPT); | |||||
| bool OptActionGePyStub(const ResourcePtr &res) { | |||||
| if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { | |||||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||||
| // Renomalize | |||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||||
| FuncGraphPtr func_graph = res->func_graph(); | |||||
| abstract::AbstractBasePtrList args_spec; | |||||
| auto parameters = func_graph->parameters(); | |||||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||||
| res->set_func_graph(new_fg); | |||||
| res->set_args_spec(args_spec); | |||||
| } | |||||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { | |||||
| return GeOptimizeAction(res); | |||||
| } | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -510,7 +530,7 @@ std::vector<ActionItem> GePipeline() { | |||||
| // optimize | // optimize | ||||
| actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | ||||
| // Add opt-stage python pass stub | // Add opt-stage python pass stub | ||||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||||
| actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | |||||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | ||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| return actions; | return actions; | ||||
| @@ -523,7 +543,7 @@ std::vector<ActionItem> VmPipeline() { | |||||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | ||||
| // Add opt-stage python pass stub | // Add opt-stage python pass stub | ||||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||||
| actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); | |||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| @@ -12,13 +12,15 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Top-level reference to python pass.""" | |||||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm | |||||
| """Reference for python pass registration.""" | |||||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ | |||||
| set_reopt | |||||
| __all__ = [ | __all__ = [ | ||||
| "registe_pass", | "registe_pass", | ||||
| "unregiste_pass", | "unregiste_pass", | ||||
| "gen_new_parameter", | "gen_new_parameter", | ||||
| "cancel_new_parameter", | "cancel_new_parameter", | ||||
| "set_renorm" | |||||
| "set_renorm", | |||||
| "set_reopt" | |||||
| ] | ] | ||||
| @@ -23,7 +23,8 @@ __all__ = [ | |||||
| "unregiste_pass", | "unregiste_pass", | ||||
| "gen_new_parameter", | "gen_new_parameter", | ||||
| "cancel_new_parameter", | "cancel_new_parameter", | ||||
| "set_renorm" | |||||
| "set_renorm", | |||||
| "set_reopt" | |||||
| ] | ] | ||||
| class PyPassManager(PyPassManager_): | class PyPassManager(PyPassManager_): | ||||
| r""" | r""" | ||||
| @@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_): | |||||
| raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") | raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") | ||||
| super().set_renorm(should_renorm) | super().set_renorm(should_renorm) | ||||
| def set_reopt(self, do_reopt): | |||||
| if not isinstance(do_reopt, bool): | |||||
| raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") | |||||
| super().set_reopt(do_reopt) | |||||
| def registe_pass(run_only_once=False): | def registe_pass(run_only_once=False): | ||||
| """ | """ | ||||
| Registe python pass to specified pipeline phase which would be used in compilation. | Registe python pass to specified pipeline phase which would be used in compilation. | ||||
| @@ -164,3 +170,17 @@ def set_renorm(should_renorm): | |||||
| """ | """ | ||||
| ppm = PyPassManager() | ppm = PyPassManager() | ||||
| ppm.set_renorm(should_renorm) | ppm.set_renorm(should_renorm) | ||||
| def set_reopt(do_reopt): | |||||
| """ | |||||
| Set whether or not to do optimization after modified graph in python pass(es). | |||||
| Args: | |||||
| do_reopt(bool): whether or not to do optimization after modified graph in python pass(es). | |||||
| NOTE: | |||||
| This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off | |||||
| renormalization may BREAK the network. | |||||
| """ | |||||
| ppm = PyPassManager() | |||||
| ppm.set_reopt(do_reopt) | |||||
| @@ -20,7 +20,7 @@ from mindspore import context | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ | from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ | ||||
| cancel_new_parameter | |||||
| cancel_new_parameter, set_reopt | |||||
| from mindspore.common.api import _generate_pip_args | from mindspore.common.api import _generate_pip_args | ||||
| from mindspore._c_expression import generate_key, Executor_ | from mindspore._c_expression import generate_key, Executor_ | ||||
| from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm | from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm | ||||
| @@ -50,8 +50,8 @@ def test_softmax_relu(): | |||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_relu_pass(): | def softmax_relu_pass(): | ||||
| x = Any() | x = Any() | ||||
| pattern = Call(P.Softmax(), inputs=[x]) | |||||
| target = Call(P.ReLU(), inputs=[x]) | |||||
| pattern = Call(P.Softmax(), [x]) | |||||
| target = Call(P.ReLU(), [x]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | ||||
| @@ -59,6 +59,23 @@ def test_softmax_relu(): | |||||
| assert "ReLU" in transformed_repr | assert "ReLU" in transformed_repr | ||||
| assert "Softmax" not in transformed_repr | assert "Softmax" not in transformed_repr | ||||
| def test_prim(): | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||||
| softmax_model = nn.Softmax() | |||||
| @registe_pass(run_only_once=True) | |||||
| def softmax_relu_pass(): | |||||
| x = Any() | |||||
| sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()]) | |||||
| pattern = Call(sigmoid_softmax_pattern, [x]) | |||||
| target = Call(P.ReLU(), [x]) | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | |||||
| unregiste_pass(softmax_relu_pass) | |||||
| assert "ReLU" in transformed_repr | |||||
| assert "Softmax" not in transformed_repr | |||||
| def test_softmax_relu_sigmoid(): | def test_softmax_relu_sigmoid(): | ||||
| """ | """ | ||||
| Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)). | Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)). | ||||
| @@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid(): | |||||
| def softmax_relu_pass(): | def softmax_relu_pass(): | ||||
| x = Any() | x = Any() | ||||
| softmax_pattern = Prim(P.Softmax()) | softmax_pattern = Prim(P.Softmax()) | ||||
| pattern = Call(softmax_pattern, inputs=[x]) | |||||
| pattern = Call(softmax_pattern, [x]) | |||||
| sigmoid_pattern = Prim(P.Sigmoid()) | sigmoid_pattern = Prim(P.Sigmoid()) | ||||
| call_sigmoid = Call(sigmoid_pattern, [x]) | call_sigmoid = Call(sigmoid_pattern, [x]) | ||||
| relu_pattern = Prim(P.ReLU()) | relu_pattern = Prim(P.ReLU()) | ||||
| target = Call(relu_pattern, inputs=[call_sigmoid]) | |||||
| target = Call(relu_pattern, [call_sigmoid]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | ||||
| @@ -98,13 +115,13 @@ def test_isin_pattern_0(): | |||||
| def softmax_relu_pass(): | def softmax_relu_pass(): | ||||
| x = Any() | x = Any() | ||||
| softmax_pattern = Prim(P.Softmax()) | softmax_pattern = Prim(P.Softmax()) | ||||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||||
| call_softmax = Call(softmax_pattern, [x]) | |||||
| relu_pattern = Prim(P.ReLU()) | relu_pattern = Prim(P.ReLU()) | ||||
| call_relu = Call(relu_pattern, inputs=[x]) | |||||
| call_relu = Call(relu_pattern, [x]) | |||||
| pattern = OneOf([call_softmax, call_relu]) | pattern = OneOf([call_softmax, call_relu]) | ||||
| relu6_pattern = Prim(P.ReLU6()) | relu6_pattern = Prim(P.ReLU6()) | ||||
| target = Call(relu6_pattern, inputs=[x]) | |||||
| target = Call(relu6_pattern, [x]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | ||||
| unregiste_pass(softmax_relu_pass) | unregiste_pass(softmax_relu_pass) | ||||
| @@ -122,13 +139,13 @@ def test_isin_pattern_1(): | |||||
| def softmax_neg_pass(): | def softmax_neg_pass(): | ||||
| x = Any() | x = Any() | ||||
| softmax_pattern = Prim(P.Softmax()) | softmax_pattern = Prim(P.Softmax()) | ||||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||||
| call_softmax = Call(softmax_pattern, [x]) | |||||
| relu_pattern = Prim(P.ReLU()) | relu_pattern = Prim(P.ReLU()) | ||||
| call_relu = Call(relu_pattern, inputs=[x]) | |||||
| call_relu = Call(relu_pattern, [x]) | |||||
| pattern = OneOf([call_softmax, call_relu]) | pattern = OneOf([call_softmax, call_relu]) | ||||
| neg_ops = Prim(P.Neg()) | neg_ops = Prim(P.Neg()) | ||||
| target = Call(neg_ops, inputs=[pattern]) | |||||
| target = Call(neg_ops, [pattern]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) | ||||
| unregiste_pass(softmax_neg_pass) | unregiste_pass(softmax_neg_pass) | ||||
| @@ -141,6 +158,7 @@ def test_isnot_pattern_0(): | |||||
| Case: IsNot pass failed to match | Case: IsNot pass failed to match | ||||
| """ | """ | ||||
| set_renorm(False) | set_renorm(False) | ||||
| set_reopt(False) | |||||
| class ConvBN(nn.Cell): | class ConvBN(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ConvBN, self).__init__() | super(ConvBN, self).__init__() | ||||
| @@ -166,8 +184,8 @@ def test_isnot_pattern_0(): | |||||
| conv2d_prim = Prim("Conv2D") | conv2d_prim = Prim("Conv2D") | ||||
| conv2d = Call(conv2d_prim) | conv2d = Call(conv2d_prim) | ||||
| pattern_0 = NoneOf(conv2d) | pattern_0 = NoneOf(conv2d) | ||||
| pattern = Call(P.BatchNorm(), inputs=[pattern_0]) | |||||
| target = Call(P.ReLU6(), inputs=[pattern_0]) | |||||
| pattern = Call(P.BatchNorm(), [pattern_0]) | |||||
| target = Call(P.ReLU6(), [pattern_0]) | |||||
| return pattern, target | return pattern, target | ||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| @@ -202,9 +220,9 @@ def test_isnot_pattern_1(): | |||||
| matmul = Prim("MatMul") | matmul = Prim("MatMul") | ||||
| pattern_0 = NoneOf(matmul) | pattern_0 = NoneOf(matmul) | ||||
| softmax = P.Softmax() | softmax = P.Softmax() | ||||
| pattern = Call(softmax, inputs=[pattern_0]) | |||||
| pattern = Call(softmax, [pattern_0]) | |||||
| relu6 = P.ReLU6() | relu6 = P.ReLU6() | ||||
| target = Call(relu6, inputs=[pattern_0]) | |||||
| target = Call(relu6, [pattern_0]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | ||||
| @@ -217,17 +235,18 @@ def test_newtensor_pattern(): | |||||
| Test NewTensor pattern in the target | Test NewTensor pattern in the target | ||||
| """ | """ | ||||
| set_renorm(False) | set_renorm(False) | ||||
| set_reopt(False) | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | inputs = Tensor(np.ones([42]), mindspore.float16) | ||||
| softmax_model = nn.Softmax() | softmax_model = nn.Softmax() | ||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_addn_pass(): | def softmax_addn_pass(): | ||||
| x = Any() | x = Any() | ||||
| pattern = Call(P.Softmax(), inputs=[x]) | |||||
| pattern = Call(P.Softmax(), [x]) | |||||
| weight_tensor = Tensor(np.zeros([42]), mindspore.float16) | weight_tensor = Tensor(np.zeros([42]), mindspore.float16) | ||||
| new_weight = NewTensor(weight_tensor) | new_weight = NewTensor(weight_tensor) | ||||
| target = Call(P.AddN(), inputs=[x, new_weight]) | |||||
| target = Call(P.AddN(), [x, new_weight]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | ||||
| unregiste_pass(softmax_addn_pass) | unregiste_pass(softmax_addn_pass) | ||||
| @@ -242,17 +261,19 @@ def test_newparameter_pattern(): | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | inputs = Tensor(np.ones([42]), mindspore.float16) | ||||
| softmax_model = nn.Softmax() | softmax_model = nn.Softmax() | ||||
| set_renorm(False) | |||||
| set_reopt(False) | |||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_addn_pass(): | def softmax_addn_pass(): | ||||
| x = Any() | x = Any() | ||||
| pattern = Call(P.Softmax(), inputs=[x]) | |||||
| pattern = Call(P.Softmax(), [x]) | |||||
| default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) | default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) | ||||
| default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) | default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) | ||||
| new_para_0 = NewParameter("Merlin", default_tensor0) | new_para_0 = NewParameter("Merlin", default_tensor0) | ||||
| new_para_1 = NewParameter("Arthur", default_tensor1) | new_para_1 = NewParameter("Arthur", default_tensor1) | ||||
| target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1]) | |||||
| target = Call("make_tuple", inputs=[target_0]) | |||||
| target_0 = Call(P.MatMul(), [new_para_0, new_para_1]) | |||||
| target = Call("make_tuple", [target_0]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | ||||
| unregiste_pass(softmax_addn_pass) | unregiste_pass(softmax_addn_pass) | ||||
| @@ -267,13 +288,15 @@ def test_imm_target(): | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | inputs = Tensor(np.ones([42]), mindspore.float16) | ||||
| softmax_model = nn.Softmax() | softmax_model = nn.Softmax() | ||||
| set_renorm(False) | |||||
| set_reopt(False) | |||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_pass(): | def softmax_pass(): | ||||
| x = Any() | x = Any() | ||||
| pattern = Call(P.Softmax(), inputs=[x]) | |||||
| pattern = Call(P.Softmax(), [x]) | |||||
| imm = Imm(0) | imm = Imm(0) | ||||
| target_0 = Call("make_tuple", inputs=[pattern]) | |||||
| target = Call("tuple_getitem", inputs=[target_0, imm]) | |||||
| target_0 = Call("make_tuple", [pattern]) | |||||
| target = Call("tuple_getitem", [target_0, imm]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | ||||
| unregiste_pass(softmax_pass) | unregiste_pass(softmax_pass) | ||||
| @@ -290,14 +313,16 @@ def test_gen_new_parameter(): | |||||
| default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | ||||
| new_para = NewParameter("Merlin", default_tensor) | new_para = NewParameter("Merlin", default_tensor) | ||||
| set_renorm(False) | |||||
| set_reopt(False) | |||||
| gen_new_parameter(new_para) | gen_new_parameter(new_para) | ||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_make_tuple_pass(): | def softmax_make_tuple_pass(): | ||||
| x = Any() | x = Any() | ||||
| softmax = P.Softmax() | softmax = P.Softmax() | ||||
| pattern = Call(softmax, inputs=[x]) | |||||
| pattern = Call(softmax, [x]) | |||||
| target = Call("make_tuple", inputs=[pattern, new_para]) | |||||
| target = Call("make_tuple", [pattern, new_para]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | ||||
| assert "Merlin" in transformed_repr | assert "Merlin" in transformed_repr | ||||