| @@ -59,6 +59,7 @@ class Pattern : public Base { | |||
| string unique_name() const { return unique_name_; } | |||
| vector<PatternPtr> inputs() { return inputs_; } | |||
| virtual void reset() {} | |||
| static void reset_gid() { g_id_ = 0; } | |||
| protected: | |||
| static int g_id_; | |||
| @@ -213,7 +214,6 @@ class NewParameter : public Pattern { | |||
| explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) | |||
| : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { | |||
| unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; | |||
| // clone input tensor | |||
| default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); | |||
| built_ = false; | |||
| } | |||
| @@ -257,7 +257,7 @@ class MatchResult { | |||
| MatchResult() {} | |||
| ~MatchResult() = default; | |||
| void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } | |||
| PatternNodeMap _result() { return match_result_; } | |||
| PatternNodeMap &_result() { return match_result_; } | |||
| AnfNodePtr get_node(const PatternPtr &pattern); | |||
| void merge(const MatchResultPtr &other_result); | |||
| void clear() { match_result_.clear(); } | |||
| @@ -27,8 +27,6 @@ | |||
| #include "pipeline/jit/resource.h" | |||
| #include "frontend/optimizer/py_pass_manager.h" | |||
| #include "utils/info.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/draw.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | |||
| bool requires_grad, bool layerwise_parallel); | |||
| std::string GetNodeRepr(AnfNodePtr node) { | |||
| if (node != nullptr) { | |||
| if (node->isa<CNode>()) { | |||
| std::string repr = "("; | |||
| auto const &inputs = node->cast<CNodePtr>()->inputs(); | |||
| for (auto &input : inputs) { | |||
| repr += " "; | |||
| repr += GetNodeRepr(input); | |||
| repr += " "; | |||
| } | |||
| repr += ")"; | |||
| return repr; | |||
| } | |||
| if (node->isa<Parameter>()) { | |||
| return "[Parameter]" + node->ToString(); | |||
| } else if (node->isa<ValueNode>()) { | |||
| return "[Value]" + GetValueNode(node)->ToString(); | |||
| } | |||
| return node->ToString(); | |||
| } | |||
| return ""; | |||
| } | |||
| bool IsTraversable(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| @@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||
| return new_node; | |||
| } | |||
| void DrawNode(string name, AnfNodePtr node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| auto new_func_graph = std::make_shared<FuncGraph>(); | |||
| new_func_graph->set_output(node, true); | |||
| if (save_graphs) { | |||
| auto ir_dump_path = save_graphs_path + "/" + name + ".ir"; | |||
| auto dot_dump_path = save_graphs_path + "/" + name + ".dot"; | |||
| DumpIR(ir_dump_path, new_func_graph); | |||
| draw::Draw(dot_dump_path, new_func_graph); | |||
| } | |||
| } | |||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | |||
| bool requires_grad, bool layerwise_parallel) { | |||
| // 1. Get current cell object | |||
| @@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor | |||
| if (py::isinstance<py::none>(top_cell)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; | |||
| } | |||
| // 2. New a Parameter object with the above-specified args | |||
| // 2. Clone default_input tensor | |||
| auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(), | |||
| default_input->data_c(), (size_t)default_input->Size()); | |||
| // 3. New a Parameter object with the above-specified args | |||
| py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); | |||
| py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel); | |||
| // 3. Add the new python Parameter object to Cell's _params atttributes | |||
| py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel); | |||
| // 4. Add the new python Parameter object to Cell's _params atttributes | |||
| top_cell.attr(SET_PARAM)(param_name, new_parameter); | |||
| // 4. Set default_param for param_node | |||
| // 5. Set default_param for param_node | |||
| ValuePtr param_value = nullptr; | |||
| bool converted = parse::ConvertData(new_parameter, ¶m_value, false); | |||
| if (!converted) { | |||
| @@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) { | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { | |||
| auto match_res = src_pattern_->match(node); | |||
| if (match_res != nullptr) { | |||
| MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node); | |||
| res->merge(match_res); | |||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | |||
| internal::Reset(dst_pattern()); | |||
| MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||
| return new_node; | |||
| } | |||
| internal::Reset(src_pattern()); | |||
| @@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||
| MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; | |||
| } | |||
| auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); | |||
| MS_LOG(DEBUG) << "Adding New parameter : " + para_name; | |||
| auto para_node = std::make_shared<Parameter>(func_graph); | |||
| MS_EXCEPTION_IF_NULL(para_node); | |||
| para_node->set_name(para_name); | |||
| @@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||
| // Reflect back to Cell._params | |||
| internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | |||
| new_para_pattern->layerwise_parallel()); | |||
| MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||
| MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||
| return true; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| @@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) | |||
| for (auto &node : graph_nodes_sorted) { | |||
| AnfNodePtr new_node = Run(func_graph, node, res); | |||
| if (new_node != nullptr && new_node != node) { | |||
| internal::DrawNode(dst_pattern_->unique_name(), new_node); | |||
| (void)manager->Replace(node, new_node); | |||
| changes = true; | |||
| } | |||
| @@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE( | |||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | |||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") | |||
| .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") | |||
| .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph"); | |||
| .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph") | |||
| .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph"); | |||
| })); | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| @@ -60,13 +60,19 @@ class PyPassManager { | |||
| MatchResultPtr GetMatchResult() { return res_; } | |||
| void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } | |||
| bool ShouldRenorm() { return should_renorm_; } | |||
| void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; } | |||
| bool ShouldReOpt() { return should_reopt_; } | |||
| void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } | |||
| pipeline::ResourcePtr GetResource() { return resource_; } | |||
| void ClearRes(); | |||
| void ClearPipelineRes() { resource_ = nullptr; } | |||
| void ClearPipelineRes() { | |||
| resource_ = nullptr; | |||
| Pattern::reset_gid(); | |||
| } | |||
| private: | |||
| bool should_renorm_ = true; | |||
| bool should_reopt_ = true; | |||
| MatchResultPtr res_; | |||
| pipeline::ResourcePtr resource_; | |||
| static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | |||
| @@ -451,35 +451,55 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||
| bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | |||
| void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||
| bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||
| MS_EXCEPTION_IF_NULL(res->manager()); | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | |||
| ppm->SetResource(res); | |||
| if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { | |||
| MS_LOG(DEBUG) << "No match.\n"; | |||
| } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||
| MS_LOG(DEBUG) << "Entered PyStub Renorm"; | |||
| // Renomalize | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| res->set_args_spec(args_spec); | |||
| } | |||
| return ppm->GetPassGroup(phase)->Run(res->func_graph()); | |||
| } | |||
| bool ResolveActionPyStub(const ResourcePtr &res) { | |||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||
| bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); } | |||
| bool OptActionVmPyStub(const ResourcePtr &res) { | |||
| if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { | |||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||
| // Renomalize | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| res->set_args_spec(args_spec); | |||
| } | |||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { | |||
| return VmOptimizeAction(res); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool OptActionPyStub(const ResourcePtr &res) { | |||
| ActionPyStub(res, opt::python_pass::Phase::OPT); | |||
| bool OptActionGePyStub(const ResourcePtr &res) { | |||
| if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { | |||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||
| // Renomalize | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| res->set_args_spec(args_spec); | |||
| } | |||
| if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { | |||
| return GeOptimizeAction(res); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -510,7 +530,7 @@ std::vector<ActionItem> GePipeline() { | |||
| // optimize | |||
| actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | |||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| return actions; | |||
| @@ -523,7 +543,7 @@ std::vector<ActionItem> VmPipeline() { | |||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| @@ -12,13 +12,15 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Top-level reference to python pass.""" | |||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm | |||
| """Reference for python pass registration.""" | |||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ | |||
| set_reopt | |||
| __all__ = [ | |||
| "registe_pass", | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm" | |||
| "set_renorm", | |||
| "set_reopt" | |||
| ] | |||
| @@ -23,7 +23,8 @@ __all__ = [ | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm" | |||
| "set_renorm", | |||
| "set_reopt" | |||
| ] | |||
| class PyPassManager(PyPassManager_): | |||
| r""" | |||
| @@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_): | |||
| raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") | |||
| super().set_renorm(should_renorm) | |||
| def set_reopt(self, do_reopt): | |||
| if not isinstance(do_reopt, bool): | |||
| raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") | |||
| super().set_reopt(do_reopt) | |||
| def registe_pass(run_only_once=False): | |||
| """ | |||
| Registe python pass to specified pipeline phase which would be used in compilation. | |||
| @@ -164,3 +170,17 @@ def set_renorm(should_renorm): | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.set_renorm(should_renorm) | |||
| def set_reopt(do_reopt): | |||
| """ | |||
| Set whether or not to do optimization after modified graph in python pass(es). | |||
| Args: | |||
| do_reopt(bool): whether or not to do optimization after modified graph in python pass(es). | |||
| NOTE: | |||
| This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off | |||
| renormalization may BREAK the network. | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.set_reopt(do_reopt) | |||
| @@ -20,7 +20,7 @@ from mindspore import context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ | |||
| cancel_new_parameter | |||
| cancel_new_parameter, set_reopt | |||
| from mindspore.common.api import _generate_pip_args | |||
| from mindspore._c_expression import generate_key, Executor_ | |||
| from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm | |||
| @@ -50,8 +50,8 @@ def test_softmax_relu(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| target = Call(P.ReLU(), inputs=[x]) | |||
| pattern = Call(P.Softmax(), [x]) | |||
| target = Call(P.ReLU(), [x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| @@ -59,6 +59,23 @@ def test_softmax_relu(): | |||
| assert "ReLU" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_prim(): | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = Any() | |||
| sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()]) | |||
| pattern = Call(sigmoid_softmax_pattern, [x]) | |||
| target = Call(P.ReLU(), [x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | |||
| unregiste_pass(softmax_relu_pass) | |||
| assert "ReLU" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_softmax_relu_sigmoid(): | |||
| """ | |||
| Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)). | |||
| @@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid(): | |||
| def softmax_relu_pass(): | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| pattern = Call(softmax_pattern, inputs=[x]) | |||
| pattern = Call(softmax_pattern, [x]) | |||
| sigmoid_pattern = Prim(P.Sigmoid()) | |||
| call_sigmoid = Call(sigmoid_pattern, [x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| target = Call(relu_pattern, inputs=[call_sigmoid]) | |||
| target = Call(relu_pattern, [call_sigmoid]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | |||
| @@ -98,13 +115,13 @@ def test_isin_pattern_0(): | |||
| def softmax_relu_pass(): | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||
| call_softmax = Call(softmax_pattern, [x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| call_relu = Call(relu_pattern, inputs=[x]) | |||
| call_relu = Call(relu_pattern, [x]) | |||
| pattern = OneOf([call_softmax, call_relu]) | |||
| relu6_pattern = Prim(P.ReLU6()) | |||
| target = Call(relu6_pattern, inputs=[x]) | |||
| target = Call(relu6_pattern, [x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| unregiste_pass(softmax_relu_pass) | |||
| @@ -122,13 +139,13 @@ def test_isin_pattern_1(): | |||
| def softmax_neg_pass(): | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||
| call_softmax = Call(softmax_pattern, [x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| call_relu = Call(relu_pattern, inputs=[x]) | |||
| call_relu = Call(relu_pattern, [x]) | |||
| pattern = OneOf([call_softmax, call_relu]) | |||
| neg_ops = Prim(P.Neg()) | |||
| target = Call(neg_ops, inputs=[pattern]) | |||
| target = Call(neg_ops, [pattern]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) | |||
| unregiste_pass(softmax_neg_pass) | |||
| @@ -141,6 +158,7 @@ def test_isnot_pattern_0(): | |||
| Case: IsNot pass failed to match | |||
| """ | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| class ConvBN(nn.Cell): | |||
| def __init__(self): | |||
| super(ConvBN, self).__init__() | |||
| @@ -166,8 +184,8 @@ def test_isnot_pattern_0(): | |||
| conv2d_prim = Prim("Conv2D") | |||
| conv2d = Call(conv2d_prim) | |||
| pattern_0 = NoneOf(conv2d) | |||
| pattern = Call(P.BatchNorm(), inputs=[pattern_0]) | |||
| target = Call(P.ReLU6(), inputs=[pattern_0]) | |||
| pattern = Call(P.BatchNorm(), [pattern_0]) | |||
| target = Call(P.ReLU6(), [pattern_0]) | |||
| return pattern, target | |||
| @registe_pass(run_only_once=True) | |||
| @@ -202,9 +220,9 @@ def test_isnot_pattern_1(): | |||
| matmul = Prim("MatMul") | |||
| pattern_0 = NoneOf(matmul) | |||
| softmax = P.Softmax() | |||
| pattern = Call(softmax, inputs=[pattern_0]) | |||
| pattern = Call(softmax, [pattern_0]) | |||
| relu6 = P.ReLU6() | |||
| target = Call(relu6, inputs=[pattern_0]) | |||
| target = Call(relu6, [pattern_0]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| @@ -217,17 +235,18 @@ def test_newtensor_pattern(): | |||
| Test NewTensor pattern in the target | |||
| """ | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| pattern = Call(P.Softmax(), [x]) | |||
| weight_tensor = Tensor(np.zeros([42]), mindspore.float16) | |||
| new_weight = NewTensor(weight_tensor) | |||
| target = Call(P.AddN(), inputs=[x, new_weight]) | |||
| target = Call(P.AddN(), [x, new_weight]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| unregiste_pass(softmax_addn_pass) | |||
| @@ -242,17 +261,19 @@ def test_newparameter_pattern(): | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| pattern = Call(P.Softmax(), [x]) | |||
| default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para_0 = NewParameter("Merlin", default_tensor0) | |||
| new_para_1 = NewParameter("Arthur", default_tensor1) | |||
| target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1]) | |||
| target = Call("make_tuple", inputs=[target_0]) | |||
| target_0 = Call(P.MatMul(), [new_para_0, new_para_1]) | |||
| target = Call("make_tuple", [target_0]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| unregiste_pass(softmax_addn_pass) | |||
| @@ -267,13 +288,15 @@ def test_imm_target(): | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| pattern = Call(P.Softmax(), [x]) | |||
| imm = Imm(0) | |||
| target_0 = Call("make_tuple", inputs=[pattern]) | |||
| target = Call("tuple_getitem", inputs=[target_0, imm]) | |||
| target_0 = Call("make_tuple", [pattern]) | |||
| target = Call("tuple_getitem", [target_0, imm]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| unregiste_pass(softmax_pass) | |||
| @@ -290,14 +313,16 @@ def test_gen_new_parameter(): | |||
| default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para = NewParameter("Merlin", default_tensor) | |||
| set_renorm(False) | |||
| set_reopt(False) | |||
| gen_new_parameter(new_para) | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_make_tuple_pass(): | |||
| x = Any() | |||
| softmax = P.Softmax() | |||
| pattern = Call(softmax, inputs=[x]) | |||
| pattern = Call(softmax, [x]) | |||
| target = Call("make_tuple", inputs=[pattern, new_para]) | |||
| target = Call("make_tuple", [pattern, new_para]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| assert "Merlin" in transformed_repr | |||