| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/pass_group.h" | |||
| #include "frontend/optimizer/py_pass_manager.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) { | |||
| return false; | |||
| } | |||
| bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const { | |||
| bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, | |||
| const MatchResultPtr &res) const { | |||
| if (func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| bool changed = false; | |||
| for (const auto &pass : passes) { | |||
| if (pass != nullptr) { | |||
| if (pass->Run(func_graph)) { | |||
| if (pass->Run(func_graph, res)) { | |||
| changed = true; | |||
| } | |||
| } | |||
| @@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const { | |||
| bool changed = false; | |||
| // run all passes | |||
| bool change = true; | |||
| auto res = PyPassManager::GetInstance()->GetMatchResult(); | |||
| while (change) { | |||
| change = Run(func_graph, passes_); | |||
| change = Run(func_graph, passes_, res); | |||
| changed = change || changed; | |||
| if (run_only_once_) { | |||
| break; | |||
| @@ -41,12 +41,14 @@ class PassGroup { | |||
| // @return false, graph not changed | |||
| bool Run(const FuncGraphPtr &func_graph) const; | |||
| // Run the given graph passes on the input graph | |||
| // @param [inout] graph The graph to be optimized | |||
| // @param [inout] func_graph The graph to be optimized | |||
| // @param [in] passes The given graph passes | |||
| // @param [inout] res MatchResult used to collect all matched patterns and nodes | |||
| // @return true, graph changed | |||
| // @return false, graph not changed | |||
| bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const; | |||
| bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const; | |||
| std::string name() const { return name_; } | |||
| void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; } | |||
| private: | |||
| const std::string name_; | |||
| @@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) { | |||
| for (auto &iter : patterns_) { | |||
| auto res = iter->match(node); | |||
| if (res != nullptr) { | |||
| res->add_entry(shared_from_base<IsIn>(), node); | |||
| return res; | |||
| } | |||
| } | |||
| @@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>()); | |||
| (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_") | |||
| .def(py::init<tensor::TensorPtr>()); | |||
| (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_") | |||
| .def(py::init<string, tensor::TensorPtr, bool, bool, bool>()); | |||
| (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>()); | |||
| })); | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| @@ -42,6 +42,10 @@ class CallWith; | |||
| using CallWithPtr = std::shared_ptr<CallWith>; | |||
| class NewTensor; | |||
| using NewTensorPtr = std::shared_ptr<NewTensor>; | |||
| class NewParameter; | |||
| using NewParameterPtr = std::shared_ptr<NewParameter>; | |||
| class Imm; | |||
| using ImmPtr = std::shared_ptr<Imm>; | |||
| struct PatternHasher; | |||
| struct PatternEqual; | |||
| using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>; | |||
| @@ -55,6 +59,7 @@ class Pattern : public Base { | |||
| string unique_name() const { return unique_name_; } | |||
| vector<PatternPtr> inputs() { return inputs_; } | |||
| bool should_replace() { return should_replace_; } | |||
| void set_should_replace(bool should_replace) { should_replace_ = should_replace; } | |||
| virtual void reset() {} | |||
| protected: | |||
| @@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern { | |||
| ~IsPrimTypeOf() = default; | |||
| IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace) | |||
| : primitives_(prims), name_(name), matched_prim_(nullptr) { | |||
| unique_name_ = std::to_string(g_id_++) + "_" + name; | |||
| unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; | |||
| should_replace_ = should_replace; | |||
| if (!should_replace) { | |||
| matched_prim_ = prims[0]; | |||
| } | |||
| } | |||
| IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "_" + name; | |||
| unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; | |||
| // Make primitives_ | |||
| for (auto &iter : types) { | |||
| primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr))); | |||
| @@ -126,19 +131,20 @@ class CallWith : public Pattern { | |||
| CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) { | |||
| // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting | |||
| prim_pattern_ = prim_pattern; | |||
| unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name(); | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name(); | |||
| inputs_ = inputs; | |||
| should_replace_ = should_replace; | |||
| // NOTE: should_replace_ is silently overridden by its prim_pattern (if one exists). | |||
| should_replace_ = prim_pattern->should_replace(); | |||
| } | |||
| CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) { | |||
| prim_ = prim; | |||
| unique_name_ = std::to_string(g_id_++) + prim_->ToString(); | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| should_replace_ = should_replace; | |||
| } | |||
| CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) { | |||
| prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr)); | |||
| unique_name_ = std::to_string(g_id_++) + prim_->ToString(); | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| should_replace_ = should_replace; | |||
| } | |||
| @@ -159,7 +165,7 @@ class IsIn : public Pattern { | |||
| IsIn() { unique_name_ = std::to_string(g_id_++); } | |||
| ~IsIn() = default; | |||
| explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++); | |||
| unique_name_ = std::to_string(g_id_++) + "IsIn"; | |||
| for (auto &iter : patterns) { | |||
| unique_name_ = unique_name_ + "_" + iter->unique_name(); | |||
| } | |||
| @@ -176,9 +182,9 @@ class IsNot : public Pattern { | |||
| IsNot() { unique_name_ = std::to_string(g_id_++); } | |||
| ~IsNot() = default; | |||
| explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++); | |||
| unique_name_ = std::to_string(g_id_++) + "IsNot"; | |||
| for (auto &iter : patterns) { | |||
| unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name(); | |||
| unique_name_ = unique_name_ + "_" + iter->unique_name(); | |||
| } | |||
| } | |||
| MS_DECLARE_PARENT(IsNot, Pattern); | |||
| @@ -200,7 +206,10 @@ class NewTensor : public Pattern { | |||
| public: | |||
| NewTensor() { unique_name_ = std::to_string(g_id_++); } | |||
| ~NewTensor() = default; | |||
| explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } | |||
| explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { | |||
| should_replace_ = false; | |||
| unique_name_ = std::to_string(g_id_++) + "NewTensor"; | |||
| } | |||
| MS_DECLARE_PARENT(NewTensor, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override { | |||
| MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; | |||
| @@ -211,6 +220,54 @@ class NewTensor : public Pattern { | |||
| tensor::TensorPtr input_tensor_; | |||
| }; | |||
| class NewParameter : public Pattern { | |||
| public: | |||
| NewParameter() { unique_name_ = std::to_string(g_id_++); } | |||
| explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel, | |||
| bool should_replace) | |||
| : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { | |||
| should_replace_ = should_replace; | |||
| unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; | |||
| // clone input tensor | |||
| default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); | |||
| built_ = false; | |||
| } | |||
| MS_DECLARE_PARENT(NewParameter, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override { | |||
| MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n"; | |||
| } | |||
| string para_name() { return para_name_; } | |||
| tensor::TensorPtr default_tensor() { return default_tensor_; } | |||
| bool requires_grad() { return requires_grad_; } | |||
| bool layerwise_parallel() { return layerwise_parallel_; } | |||
| bool built() { return built_; } | |||
| void set_built(bool built) { built_ = built; } | |||
| void reset() override { built_ = false; } | |||
| private: | |||
| string para_name_; | |||
| bool requires_grad_; | |||
| bool layerwise_parallel_; | |||
| bool built_; | |||
| tensor::TensorPtr default_tensor_; | |||
| }; | |||
| class Imm : public Pattern { | |||
| public: | |||
| Imm() { unique_name_ = std::to_string(g_id_++); } | |||
| explicit Imm(int value) : value_(value) { | |||
| should_replace_ = false; | |||
| unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); | |||
| } | |||
| MS_DECLARE_PARENT(Imm, Pattern); | |||
| // NOTE: Doesn't support Imm in src pattern currently. | |||
| MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; } | |||
| int value() { return value_; } | |||
| private: | |||
| int value_; | |||
| }; | |||
| class MatchResult { | |||
| public: | |||
| MatchResult() {} | |||
| @@ -21,13 +21,26 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/manager.h" | |||
| #include "pybind_api/ir/primitive_py.h" | |||
| #include "ir/scalar.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "pipeline/jit/parse/parse_base.h" | |||
| #include "pipeline/jit/resource.h" | |||
| #include "frontend/optimizer/py_pass_manager.h" | |||
| #include "utils/info.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/draw.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| namespace internal { | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res); | |||
| const char PARAMETER_MODULE[] = "mindspore.common.parameter"; | |||
| const char PARAMETER_CLASS[] = "Parameter"; | |||
| const char SET_PARAM[] = "__setattr__"; | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph); | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res); | |||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | |||
| bool requires_grad, bool layerwise_parallel); | |||
| std::string GetNodeRepr(AnfNodePtr node) { | |||
| if (node != nullptr) { | |||
| @@ -42,8 +55,10 @@ std::string GetNodeRepr(AnfNodePtr node) { | |||
| repr += ")"; | |||
| return repr; | |||
| } | |||
| if (node->isa<ValueNode>()) { | |||
| return GetValueNode(node)->ToString(); | |||
| if (node->isa<Parameter>()) { | |||
| return "[Parameter]" + node->ToString(); | |||
| } else if (node->isa<ValueNode>()) { | |||
| return "[Value]" + GetValueNode(node)->ToString(); | |||
| } | |||
| return node->ToString(); | |||
| } | |||
| @@ -82,7 +97,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) | |||
| return std::make_shared<ValueNode>(input_tensor); | |||
| } | |||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) { | |||
| auto call_with_pattern = pattern->cast<CallWithPtr>(); | |||
| MS_EXCEPTION_IF_NULL(call_with_pattern); | |||
| auto prim = call_with_pattern->prim_value(); | |||
| @@ -91,15 +106,70 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP | |||
| } | |||
| auto prim_pattern = call_with_pattern->prim_pattern(); | |||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||
| return ProcessSinglePattern(prim_pattern, res); | |||
| return ProcessSinglePattern(prim_pattern, res, fg); | |||
| } | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { | |||
| auto new_para_pattern = pattern->cast<NewParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_para_pattern); | |||
| if (!new_para_pattern->built()) { | |||
| static int parameter_id = 0; | |||
| auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++); | |||
| auto para_node = std::make_shared<Parameter>(func_graph); | |||
| MS_EXCEPTION_IF_NULL(para_node); | |||
| para_node->set_name(para_name); | |||
| // Set function graph | |||
| para_node->set_func_graph(func_graph); | |||
| // Set Debug Info | |||
| auto debug_info = std::make_shared<NodeDebugInfo>(para_name); | |||
| para_node->set_debug_info(debug_info); | |||
| // Set abstract | |||
| auto default_value = new_para_pattern->default_tensor(); | |||
| MS_EXCEPTION_IF_NULL(default_value); | |||
| para_node->set_abstract(default_value->ToAbstract()->Broaden()); | |||
| res->add_entry(pattern, para_node); | |||
| func_graph->add_parameter(para_node); | |||
| // Reflect back to Cell._params | |||
| internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | |||
| new_para_pattern->layerwise_parallel()); | |||
| MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||
| new_para_pattern->set_built(true); | |||
| return para_node; | |||
| } else { | |||
| // Built, fetch the node | |||
| auto para_node = res->get_node(pattern); | |||
| MS_EXCEPTION_IF_NULL(para_node); | |||
| return para_node; | |||
| } | |||
| } | |||
| AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| auto imm_pattern = pattern->cast<ImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(imm_pattern); | |||
| auto value = imm_pattern->value(); | |||
| auto scalar_value_ptr = std::make_shared<Int32Imm>(value); | |||
| return std::make_shared<ValueNode>(scalar_value_ptr); | |||
| } | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { | |||
| if (pattern->should_replace()) { | |||
| // Find replacement in the MatchResult | |||
| auto target_node = res->get_node(pattern); | |||
| if (target_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n"; | |||
| // If it's a base pattern (in contrast to a complex pattern like CallWith/IsIn/IsNot), raise a runtime exception. | |||
| if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n"; | |||
| return nullptr; | |||
| } | |||
| // Try to build this pattern and add it to MatchResult, since this pattern is defined inside the target | |||
| auto new_node = BuildTarget(pattern, func_graph, res); | |||
| if (new_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n"; | |||
| } | |||
| return new_node; | |||
| } | |||
| return target_node; | |||
| } | |||
| @@ -109,7 +179,19 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr | |||
| } else if (pattern->isa<NewTensor>()) { | |||
| return BuildNewTensor(pattern, res); | |||
| } else if (pattern->isa<CallWith>()) { | |||
| return BuildPrimitiveValueNode(pattern, res); | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| } else if (pattern->isa<NewParameter>()) { | |||
| return BuildNewParameter(pattern, res, func_graph); | |||
| } else if (pattern->isa<Imm>()) { | |||
| return BuildImmNode(pattern, res); | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res, | |||
| const FuncGraphPtr &func_graph) { | |||
| if (pattern->isa<CallWith>()) { | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -117,91 +199,154 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr | |||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { | |||
| auto target_inputs = pattern->inputs(); | |||
| if (target_inputs.size() == 0) { | |||
| return ProcessSinglePattern(pattern, res); | |||
| auto new_node = ProcessSinglePattern(pattern, res, func_graph); | |||
| if (new_node != nullptr) { | |||
| res->add_entry(pattern, new_node); | |||
| } | |||
| return new_node; | |||
| } | |||
| // Build up the AnfNode in a recursive manner | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| auto prim_value_node = ProcessSinglePattern(pattern, res); | |||
| auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph); | |||
| MS_EXCEPTION_IF_NULL(prim_value_node); | |||
| new_inputs.push_back(prim_value_node); | |||
| for (auto &iter : target_inputs) { | |||
| if (iter == pattern) { | |||
| MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() + | |||
| "\n"; | |||
| MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n"; | |||
| } | |||
| auto input_node = BuildTarget(iter, func_graph, res); | |||
| if (input_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n"; | |||
| } | |||
| new_inputs.push_back(input_node); | |||
| } | |||
| auto new_node = func_graph->NewCNode(new_inputs); | |||
| res->add_entry(pattern, new_node); | |||
| return new_node; | |||
| } | |||
| void DrawNode(string name, AnfNodePtr node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| auto new_func_graph = std::make_shared<FuncGraph>(); | |||
| new_func_graph->set_output(node, true); | |||
| if (save_graphs) { | |||
| auto ir_dump_path = save_graphs_path + "/" + name + ".ir"; | |||
| auto dot_dump_path = save_graphs_path + "/" + name + ".dot"; | |||
| DumpIR(ir_dump_path, new_func_graph); | |||
| draw::Draw(dot_dump_path, new_func_graph); | |||
| } | |||
| } | |||
| void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, | |||
| bool requires_grad, bool layerwise_parallel) { | |||
| // 1. Get current cell object | |||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | |||
| auto resource = ppm->GetResource(); | |||
| py::object top_cell = resource->input(); | |||
| if (py::isinstance<py::none>(top_cell)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; | |||
| } | |||
| // 2. New a Parameter object with the above-specified args | |||
| py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); | |||
| py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel); | |||
| // 3. Add the new python Parameter object to Cell's _params attribute | |||
| top_cell.attr(SET_PARAM)(param_name, new_parameter); | |||
| // 4. Set default_param for param_node | |||
| ValuePtr param_value = nullptr; | |||
| bool converted = parse::ConvertData(new_parameter, ¶m_value, false); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| auto param_node = param->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| param_node->set_default_param(param_value); | |||
| } | |||
| void Reset(PatternPtr pattern) { | |||
| if (pattern->isa<IsPrimTypeOf>()) { | |||
| auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>(); | |||
| prim_pattern->reset(); | |||
| return; | |||
| } else if (pattern->isa<NewParameter>()) { | |||
| auto new_param_pattern = pattern->cast<NewParameterPtr>(); | |||
| new_param_pattern->reset(); | |||
| return; | |||
| } else if (pattern->isa<CallWith>()) { | |||
| auto call_with_pattern = pattern->cast<CallWithPtr>(); | |||
| for (auto sub_pattern : call_with_pattern->inputs()) { | |||
| Reset(sub_pattern); | |||
| } | |||
| new_inputs.push_back(BuildTarget(iter, func_graph, res)); | |||
| return; | |||
| } | |||
| return func_graph->NewCNode(new_inputs); | |||
| return; | |||
| } | |||
| } // namespace internal | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(src_pattern_); | |||
| MS_EXCEPTION_IF_NULL(dst_pattern_); | |||
| auto res = src_pattern_->match(node); | |||
| if (res != nullptr) { | |||
| res->dump(); | |||
| MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name(); | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { | |||
| auto match_res = src_pattern_->match(node); | |||
| if (match_res != nullptr) { | |||
| MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node); | |||
| res->merge(match_res); | |||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | |||
| dst_pattern_->reset(); | |||
| MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||
| internal::Reset(dst_pattern()); | |||
| MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||
| return new_node; | |||
| } | |||
| src_pattern_->reset(); | |||
| internal::Reset(src_pattern()); | |||
| return nullptr; | |||
| } | |||
| bool PythonPass::Run(const FuncGraphPtr &func_graph) { | |||
| bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(dst_pattern_); | |||
| if (src_pattern_ == nullptr) { | |||
| // Add NewParameter | |||
| auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>(); | |||
| if (new_para_pattern == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; | |||
| } | |||
| auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); | |||
| MS_LOG(DEBUG) << "Adding New parameter : " + para_name; | |||
| auto para_node = std::make_shared<Parameter>(func_graph); | |||
| MS_EXCEPTION_IF_NULL(para_node); | |||
| para_node->set_name(para_name); | |||
| // Set function graph | |||
| para_node->set_func_graph(func_graph); | |||
| // Set Debug Info | |||
| auto debug_info = std::make_shared<NodeDebugInfo>(para_name); | |||
| para_node->set_debug_info(debug_info); | |||
| // Set abstract | |||
| auto default_value = new_para_pattern->default_tensor(); | |||
| MS_EXCEPTION_IF_NULL(default_value); | |||
| para_node->set_abstract(default_value->ToAbstract()->Broaden()); | |||
| res->add_entry(dst_pattern_, para_node); | |||
| func_graph->add_parameter(para_node); | |||
| // Reflect back to Cell._params | |||
| internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), | |||
| new_para_pattern->layerwise_parallel()); | |||
| MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); | |||
| return true; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| auto seen = NewSeenGeneration(); | |||
| // 1024 is for the initial capacity of deque | |||
| std::deque<AnfNodePtr> todo(1024); | |||
| todo.push_back(func_graph->output()); | |||
| auto graph_nodes_sorted = TopoSort(func_graph->output()); | |||
| bool changes = false; | |||
| auto &all_nodes = manager->all_nodes(); | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.front(); | |||
| todo.pop_front(); | |||
| // Check whether this node has been matched. | |||
| if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) { | |||
| continue; | |||
| } | |||
| node->seen_ = seen; | |||
| // Select nodes that this transform can be applied. | |||
| AnfNodePtr new_node = Run(func_graph, node); | |||
| bool change = (new_node != nullptr); | |||
| // Traverse once | |||
| for (auto &node : graph_nodes_sorted) { | |||
| AnfNodePtr new_node = Run(func_graph, node, res); | |||
| if (new_node != nullptr && new_node != node) { | |||
| internal::DrawNode(dst_pattern_->unique_name(), new_node); | |||
| (void)manager->Replace(node, new_node); | |||
| } else if (new_node == nullptr) { | |||
| new_node = node; | |||
| } | |||
| if (run_only_once_) { | |||
| return change; | |||
| } | |||
| // Find success, and add them to todo list | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||
| } | |||
| auto &node_users = manager->node_users(); | |||
| if (change && node_users.find(node) != node_users.end()) { | |||
| for (auto &use : node_users[node]) { | |||
| auto use_node = use.first; | |||
| if (use_node == nullptr) { | |||
| continue; | |||
| } | |||
| todo.push_back(use_node); | |||
| if (use_node->seen_ == seen) { | |||
| use_node->seen_--; | |||
| } | |||
| } | |||
| changes = true; | |||
| } | |||
| } | |||
| return changes; | |||
| @@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>; | |||
| class PythonPass { | |||
| public: | |||
| explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false, | |||
| bool multigraph = true) | |||
| : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {} | |||
| explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false) | |||
| : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once) {} | |||
| ~PythonPass() = default; | |||
| bool Run(const FuncGraphPtr &func_graph); | |||
| bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res); | |||
| std::string name() const { return name_; } | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res); | |||
| PatternPtr src_pattern() { return src_pattern_; } | |||
| PatternPtr dst_pattern() { return dst_pattern_; } | |||
| private: | |||
| PatternPtr src_pattern_; | |||
| PatternPtr dst_pattern_; | |||
| const std::string name_; | |||
| bool run_only_once_; | |||
| bool multigraph_ = true; | |||
| }; | |||
| using PythonPassPtr = std::shared_ptr<PythonPass>; | |||
| @@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() { | |||
| PyPassManager::PyPassManager() { | |||
| phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>(); | |||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | |||
| res_ = std::make_shared<MatchResult>(); | |||
| } | |||
| void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| Phase phase, bool run_only_once, bool multigraph) { | |||
| auto cur_pm = GetPassGroup(phase); | |||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||
| PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph); | |||
| cur_pm->AddPass(new_pass); | |||
| Phase phase, bool run_only_once) { | |||
| auto cur_pg = GetPassGroup(phase); | |||
| MS_EXCEPTION_IF_NULL(cur_pg); | |||
| cur_pg->SetRunOnlyOnce(run_only_once); | |||
| MS_EXCEPTION_IF_NULL(pattern); | |||
| MS_EXCEPTION_IF_NULL(target); | |||
| PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once); | |||
| cur_pg->AddPass(new_pass); | |||
| } | |||
| void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||
| @@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||
| } | |||
| } | |||
| void PyPassManager::GenNewParameter(const PatternPtr ¶meter) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| // Add new parameter after resolve | |||
| // NOTE: Adding NewParameter at an early stage would cause CSE problems | |||
| auto cur_pg = GetPassGroup(Phase::OPT); | |||
| MS_EXCEPTION_IF_NULL(cur_pg); | |||
| cur_pg->SetRunOnlyOnce(true); | |||
| auto new_para_pattern = parameter->cast<NewParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_para_pattern); | |||
| auto pass_name = new_para_pattern->para_name(); | |||
| parameter->set_should_replace(false); | |||
| auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true); | |||
| cur_pg->AddPass(new_pass); | |||
| } | |||
| void PyPassManager::ClearRes() { | |||
| MS_LOG(INFO) << "Clear PyPassManager resources!"; | |||
| global_instance = nullptr; | |||
| @@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_") | |||
| .def(py::init([]() { return PyPassManager::GetInstance(); })) | |||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | |||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); | |||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") | |||
| .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") | |||
| .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph"); | |||
| })); | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| @@ -27,7 +27,7 @@ | |||
| #include "ir/graph_utils.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| #include "pipeline/jit/resource.h" | |||
| #include "frontend/optimizer/pattern.h" | |||
| #include "frontend/optimizer/py_pass.h" | |||
| #include "frontend/optimizer/pass_group.h" | |||
| @@ -53,12 +53,21 @@ class PyPassManager { | |||
| static PyPassManagerPtr GetInstance(); | |||
| virtual ~PyPassManager() = default; | |||
| void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); | |||
| Phase phase = Phase::RESOLVE, bool run_only_once = false); | |||
| void Unregiste(const std::string &pass_name, Phase phase); | |||
| void GenNewParameter(const PatternPtr ¶meter); | |||
| PassGroupPtr GetPassGroup(Phase phase); | |||
| void ClearRes(); | |||
| MatchResultPtr GetMatchResult() { return res_; } | |||
| void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } | |||
| bool ShouldRenorm() { return should_renorm_; } | |||
| void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } | |||
| pipeline::ResourcePtr GetResource() { return resource_; } | |||
| private: | |||
| bool should_renorm_ = true; | |||
| MatchResultPtr res_; | |||
| pipeline::ResourcePtr resource_; | |||
| static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | |||
| }; | |||
| } // namespace python_pass | |||
| @@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||
| MS_EXCEPTION_IF_NULL(res->manager()); | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | |||
| ppm->SetResource(res); | |||
| if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { | |||
| MS_LOG(DEBUG) << "No match.\n"; | |||
| } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { | |||
| MS_LOG(DEBUG) << "Entered PyStub Renorm"; | |||
| // Renormalize | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| res->set_args_spec(args_spec); | |||
| } | |||
| } | |||
| @@ -477,6 +490,7 @@ static std::vector<ActionItem> CommonPipeline() { | |||
| } | |||
| // Add resolve-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); | |||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | |||
| // Evaluate type and shape, and specialize | |||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | |||
| @@ -1,81 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Python pass register""" | |||
| from inspect import isfunction | |||
| from mindspore.common.graph_pattern import Pattern | |||
| from mindspore._c_expression import PyPassManager_ | |||
| from mindspore._c_expression import phase | |||
| class PyPassManager(PyPassManager_): | |||
| r""" | |||
| Used to registe and unregiste python passes which can be used to alter graphs. | |||
| Args: | |||
| pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. | |||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||
| multigraph (bool): Whether or not the pattern exists across graphs. Default: True. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| """ | |||
| def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if not isinstance(run_only_once, bool): | |||
| raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}") | |||
| if not isinstance(multi_graph, bool): | |||
| raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}") | |||
| PyPassManager_.__init__(self) | |||
| self.phase_ = pipeline_phase | |||
| self.run_only_once_ = run_only_once | |||
| self.multi_graph_ = multi_graph | |||
| def registe(self, py_pass): | |||
| if not isfunction(py_pass): | |||
| raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") | |||
| pattern, target = py_pass() | |||
| pass_name = py_pass.__name__ | |||
| if not isinstance(pattern, Pattern): | |||
| raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}") | |||
| if not isinstance(target, Pattern): | |||
| raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}") | |||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) | |||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if isinstance(py_pass, str): | |||
| super().unregiste(py_pass, pipeline_phase) | |||
| return | |||
| if isfunction(py_pass): | |||
| super().unregiste(py_pass.__name__, pipeline_phase) | |||
| return | |||
| raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}") | |||
| def __call__(self, py_pass): | |||
| self.registe(py_pass) | |||
| return py_pass | |||
| def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||
| """ | |||
| Examples: | |||
| >>> @registe_pass() | |||
| >>> def toy_pass(): | |||
| >>> def pattern(): | |||
| >>> pass | |||
| >>> def target(): | |||
| >>> pass | |||
| """ | |||
| return PyPassManager(pipeline_phase, run_only_once, multi_graph) | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Top-level reference to python pass.""" | |||
| @@ -15,7 +15,8 @@ | |||
| """Patterns for describing graphs""" | |||
| from mindspore.ops import Primitive | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_ | |||
| from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\ | |||
| NewParameter_, Imm | |||
| __all__ = [ | |||
| "IsIn", | |||
| @@ -24,17 +25,25 @@ __all__ = [ | |||
| "IsNot", | |||
| "AnyPattern", | |||
| "NewTensor", | |||
| "NewParameter", | |||
| "Imm" | |||
| ] | |||
| class IsIn(IsIn_): | |||
| """ | |||
| r""" | |||
| Express a pattern which matches any one pattern from a list of allowed patterns. | |||
| """ | |||
| def __init__(self, patterns=None, should_replace=True): | |||
| r""" | |||
| Args: | |||
| patterns(list/tuple): list of allowed patterns | |||
| patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`], | |||
| list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns, | |||
| each element should be one of the exposed Pattern instances. | |||
| should_replace(bool): present only for interface consistency; set should_replace in sub-patterns instead. | |||
| Raises: | |||
| ValueError: if should_replace is False. | |||
| TypeError: if inputs are of invalid type. | |||
| """ | |||
| if not should_replace: | |||
| raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ | |||
| @@ -52,19 +61,28 @@ class IsIn(IsIn_): | |||
| class IsPrimTypeOf(IsPrimTypeOf_): | |||
| r""" | |||
| Express a pattern of certain primitive type(s). | |||
| NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed, | |||
| please refer to CallWith pattern. | |||
| NOTE: | |||
| This pattern will match and only match the primitive value node. If matching primitive CNode is needed, | |||
| please refer to CallWith pattern. | |||
| """ | |||
| def __init__(self, types, name=None, should_replace=True): | |||
| r""" | |||
| Args: | |||
| types (str/(list/tuple of Primitives)): Specify allowed types. | |||
| types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`], | |||
| tuple[:class:`mindspore.ops.Primitive`]]): | |||
| Specify allowed types. | |||
| If it is a string, the form could be | |||
| 1) a single primitive type, e.g. 'Conv2D' | |||
| 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' | |||
| It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)] | |||
| name (str): name of the pattern, optional | |||
| should_replace | |||
| It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)] | |||
| name (str): name of the pattern, optional. Default: None. | |||
| should_replace(bool): If the pattern is part of the pass replacement target, this sets how the pattern is | |||
| used when building the replacement node: reuse the captured node if True, build it from scratch otherwise. | |||
| Default: True. | |||
| Raises: | |||
| TypeError: if an argument is of invalid type. | |||
| """ | |||
| if name is not None and not isinstance(name, str): | |||
| raise TypeError(f"Expect string, got : {name}") | |||
| @@ -91,12 +109,21 @@ class CallWith(CallWith_): | |||
| r""" | |||
| Express a primitive CNode. | |||
| """ | |||
| def __init__(self, prim_pattern, inputs=None, should_replace=False): | |||
| def __init__(self, prim_pattern, inputs=None, should_replace=True): | |||
| r""" | |||
| Args: | |||
| prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode. | |||
| inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; | |||
| if specified, input patterns should be of right order. | |||
| prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`, | |||
| :class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode. | |||
| inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`], | |||
| tuple[:class:`mindspore.graph_utils.graph_pattern`]]): | |||
| Specify input patterns for the primitive(s), optional. If None, any inputs are accepted; if specified, input | |||
| patterns should be in the right order and each element should be one of the exposed Pattern instances. | |||
| should_replace(bool): If the pattern is part of the pass replacement target, this sets how the pattern is | |||
| used when building the replacement node: reuse the captured node if True, build it from scratch otherwise. | |||
| Default: True. | |||
| Raises: | |||
| TypeError: if an argument is of invalid type. | |||
| """ | |||
| if not isinstance(prim_pattern, (Pattern, str, Primitive)): | |||
| raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") | |||
| @@ -110,17 +137,23 @@ class CallWith(CallWith_): | |||
| raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") | |||
| CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) | |||
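| # A short sketch of a primitive CNode pattern with explicit inputs, mirroring the tests further below: | |||
| # | |||
| #     from mindspore.ops import operations as P | |||
| #     from mindspore.graph_utils.graph_pattern import AnyPattern, CallWith, IsPrimTypeOf | |||
| # | |||
| #     x = AnyPattern()  # matches any input node | |||
| #     softmax_call = CallWith(IsPrimTypeOf(P.Softmax()), inputs=[x])  # a Softmax CNode with one input | |||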
| class IsNot(IsNot_): | |||
| r""" | |||
| Express a pattern which forbids a list of patterns. | |||
| NOTE: IsNot pattern should not be the root pattern. | |||
| NOTE: | |||
| IsNot pattern should not be the root pattern. | |||
| """ | |||
| def __init__(self, patterns=None, should_replace=True): | |||
| r""" | |||
| Args: | |||
| patterns(list/tuple): list of forbiden patterns | |||
| patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]): list of forbidden patterns, each element | |||
| should be one of the exposed Pattern instances. | |||
| should_replace(bool): present only for interface consistency; set should_replace in sub-patterns instead. | |||
| Raises: | |||
| ValueError: if should_replace is False. | |||
| TypeError: if an argument is of invalid type. | |||
| """ | |||
| if not should_replace: | |||
| raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ | |||
| @@ -142,13 +175,48 @@ class NewTensor(NewTensor_): | |||
| def __init__(self, input_tensor, should_replace=False): | |||
| r""" | |||
| Args: | |||
| input_tensor(Tensor): new tensor to be used in the target | |||
| input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target | |||
| should_replace(bool): present only for interface consistency; NewTensor should only appear in the target. | |||
| Raises: | |||
| ValueError: if should_replace is True. | |||
| TypeError: if an argument is of invalid type. | |||
| """ | |||
| if should_replace: | |||
| raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.") | |||
| raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.") | |||
| self.input_tensor = input_tensor | |||
| if isinstance(input_tensor, Tensor): | |||
| NewTensor_.__init__(self, input_tensor) | |||
| else: | |||
| raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") | |||
| class NewParameter(NewParameter_): | |||
| r""" | |||
| New Parameter to be used in the target. | |||
| """ | |||
| def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False): | |||
| r""" | |||
| Args: | |||
| para_name(str): name for the new Parameter | |||
| default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter | |||
| requires_grad(bool): True if the parameter requires gradient. Default: False | |||
| layerwise_parallel(bool): switch for layerwise parallel mode. Default: False | |||
| should_replace(bool): build the new parameter once and reuse it afterwards if True; otherwise build a new | |||
| parameter every time a pass target gets built. Default: False | |||
| Raises: | |||
| TypeError: if an argument is of invalid type. | |||
| """ | |||
| self.para_name = para_name | |||
| self.default_tensor = default_tensor | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| self.should_replace = should_replace | |||
| if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ | |||
| isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool): | |||
| NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, | |||
| self.layerwise_parallel, self.should_replace) | |||
| else: | |||
| raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ | |||
| layerwise_parallel(bool), should_replace(bool), got : {para_name}, {default_tensor}, \ | |||
| {requires_grad}, {layerwise_parallel}, {should_replace}") | |||
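| # A minimal construction sketch for NewParameter (shape and dtype are illustrative assumptions): | |||
| # | |||
| #     import numpy as np | |||
| #     import mindspore | |||
| #     from mindspore.common.tensor import Tensor | |||
| #     from mindspore.graph_utils.graph_pattern import NewParameter | |||
| # | |||
| #     default = Tensor(np.zeros([2, 2]), mindspore.float32) | |||
| #     # should_replace=True builds the parameter once and reuses it on later matches | |||
| #     new_bias = NewParameter("new_bias", default, requires_grad=True, should_replace=True) | |||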
| @@ -0,0 +1,24 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Top-level reference to python pass.""" | |||
| from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm | |||
| __all__ = [ | |||
| "registe_pass", | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm" | |||
| ] | |||
| @@ -0,0 +1,170 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Python pass register""" | |||
| from inspect import isfunction | |||
| from mindspore.graph_utils.graph_pattern import Pattern, NewParameter | |||
| from mindspore._c_expression import PyPassManager_, phase | |||
| __all__ = [ | |||
| "registe_pass", | |||
| "unregiste_pass", | |||
| "gen_new_parameter", | |||
| "cancel_new_parameter", | |||
| "set_renorm" | |||
| ] | |||
| class PyPassManager(PyPassManager_): | |||
| r""" | |||
| Used to register and unregister python passes which can be used to alter graphs. | |||
| Args: | |||
| pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. | |||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| """ | |||
| def __init__(self, pipeline_phase=phase.opt, run_only_once=False): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if not isinstance(run_only_once, bool): | |||
| raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}") | |||
| PyPassManager_.__init__(self) | |||
| self.phase_ = pipeline_phase | |||
| self.run_only_once_ = run_only_once | |||
| def registe(self, py_pass): | |||
| if not isfunction(py_pass): | |||
| raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") | |||
| pattern, target = py_pass() | |||
| pass_name = py_pass.__name__ | |||
| if not isinstance(pattern, Pattern): | |||
| raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") | |||
| if not isinstance(target, Pattern): | |||
| raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") | |||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_) | |||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if isinstance(py_pass, str): | |||
| super().unregiste(py_pass, pipeline_phase) | |||
| return | |||
| if isfunction(py_pass): | |||
| super().unregiste(py_pass.__name__, pipeline_phase) | |||
| return | |||
| raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") | |||
| def __call__(self, py_pass): | |||
| self.registe(py_pass) | |||
| return py_pass | |||
| def gen_new_parameter(self, pattern): | |||
| if not isinstance(pattern, NewParameter): | |||
| raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") | |||
| super().gen_new_parameter(pattern) | |||
| def set_renorm(self, should_renorm): | |||
| if not isinstance(should_renorm, bool): | |||
| raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") | |||
| super().set_renorm(should_renorm) | |||
| def registe_pass(pipeline_phase=phase.opt, run_only_once=False): | |||
| """ | |||
| Register a python pass to the specified pipeline phase, to be used in compilation. | |||
| Args: | |||
| pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is | |||
| registered. Supports phase.resolve and phase.opt. Default: phase.opt. | |||
| run_only_once(bool): Run this pass only once if True; otherwise run the pass until it converges. Default: False. | |||
| Returns: | |||
| This function should be used as a decorator; it returns the decorated pass function. | |||
| Examples: | |||
| >>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf | |||
| >>> @registe_pass() | |||
| >>> def toy_pass(): | |||
| >>> pattern = IsPrimTypeOf("ReLU") | |||
| >>> target = IsPrimTypeOf("ReLU6") | |||
| >>> return pattern, target | |||
| """ | |||
| return PyPassManager(pipeline_phase, run_only_once) | |||
| def unregiste_pass(py_pass, pipeline_phase=phase.opt): | |||
| """ | |||
| Unregister a python pass. | |||
| Args: | |||
| py_pass(Union[str, function]): target python pass to unregister. | |||
| pipeline_phase(:class:`mindspore._c_expression.phase`): From which compilation pipeline stage the pass is | |||
| unregistered. Supports phase.resolve and phase.opt. Default: phase.opt. | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(py_pass, pipeline_phase) | |||
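| # A hedged end-to-end sketch of the decorator flow (the ReLU-to-ReLU6 rewrite is illustrative, not part of | |||
| # this patch): | |||
| # | |||
| #     from mindspore.ops import operations as P | |||
| #     from mindspore.graph_utils.graph_pattern import AnyPattern, CallWith, IsPrimTypeOf | |||
| # | |||
| #     @registe_pass(run_only_once=True) | |||
| #     def relu_to_relu6(): | |||
| #         x = AnyPattern() | |||
| #         pattern = CallWith(IsPrimTypeOf(P.ReLU()), inputs=[x]) | |||
| #         target = CallWith(IsPrimTypeOf(P.ReLU6(), should_replace=False), inputs=[x]) | |||
| #         return pattern, target | |||
| # | |||
| #     # ... compile a network here; the pass runs during compilation ... | |||
| #     unregiste_pass(relu_to_relu6) | |||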
| def gen_new_parameter(pattern): | |||
| """ | |||
| Generate the specified parameter every time a network gets compiled. | |||
| NOTE: | |||
| In this way, every pass that uses this pattern will use the same Parameter. If NewParameter is used without | |||
| gen_new_parameter, every pass match will build a new Parameter. | |||
| This registers a pass that adds the new parameter in the compilation pipeline, so later compilations will | |||
| ALSO add this parameter unless the pass is unregistered. To unregister this pass, call | |||
| cancel_new_parameter(pattern) | |||
| Args: | |||
| pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes | |||
| after gen_new_parameter. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore | |||
| >>> from mindspore.common.tensor import Tensor | |||
| >>> from mindspore.graph_utils.graph_pattern import NewParameter | |||
| >>> abc = NewParameter("abc", Tensor(np.ones([1]), mindspore.float32)) | |||
| >>> gen_new_parameter(abc) | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.gen_new_parameter(pattern) | |||
| def cancel_new_parameter(pattern): | |||
| """ | |||
| Use together with gen_new_parameter to unregister the gen_new_parameter pass. | |||
| Args: | |||
| pattern (NewParameter): NewParameter type, cancel the pass which would add a new parameter as this pattern | |||
| describes. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore | |||
| >>> from mindspore.common.tensor import Tensor | |||
| >>> from mindspore.graph_utils.graph_pattern import NewParameter | |||
| >>> abc = NewParameter("abc", Tensor(np.ones([1]), mindspore.float32)) | |||
| >>> gen_new_parameter(abc) | |||
| >>> # some compilations | |||
| >>> cancel_new_parameter(abc) | |||
| """ | |||
| if not isinstance(pattern, NewParameter): | |||
| raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(pattern.para_name) | |||
| def set_renorm(should_renorm): | |||
| """ | |||
| Set whether to renormalize (re-infer shapes and types) after the graph is modified by python pass(es). | |||
| Args: | |||
| should_renorm(bool): whether to renormalize. Default: True. | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.set_renorm(should_renorm) | |||
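| # Usage sketch: a pass known to preserve shapes and types could skip renormalization (whether that is safe | |||
| # is the caller's judgment; renorm is on by default): | |||
| # | |||
| #     set_renorm(False)  # skip renormalization after python passes run | |||
| #     set_renorm(True)   # restore the default behavior | |||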
| @@ -152,7 +152,7 @@ class Primitive(Primitive_): | |||
| Check if certain inputs should go to the backend. Subclass in need should override this method. | |||
| Args: | |||
| *args(Primitive args): Same as arguments of current Primitive. | |||
| args(Primitive args): Same as arguments of current Primitive. | |||
| Returns: | |||
| A tuple consisting of two elements. The first element indicates whether we should filter out current | |||
| @@ -19,10 +19,12 @@ import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.python_pass_register import registe_pass, PyPassManager | |||
| from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ | |||
| cancel_new_parameter | |||
| from mindspore.common.api import _generate_pip_args | |||
| from mindspore._c_expression import generate_key, Executor_ | |||
| from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor | |||
| from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\ | |||
| NewParameter, Imm | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -56,12 +58,39 @@ def test_softmax_relu(): | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(softmax_relu_pass) | |||
| unregiste_pass(softmax_relu_pass) | |||
| assert "ReLU" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_isin_pattern(): | |||
| def test_softmax_relu_sigmoid(): | |||
| """ | |||
| Use a python pass to transform Softmax(x) into ReLU(Sigmoid(x)). | |||
| NOTE: | |||
| The Sigmoid pattern only exists in the target. | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| pattern = CallWith(softmax_pattern, inputs=[x]) | |||
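| # Sigmoid and ReLU appear only in the target; should_replace=False | |||
| # presumably makes the matcher build them as new nodes rather than | |||
| # reuse nodes captured by the matched pattern. | |||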
| sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False) | |||
| call_sigmoid = CallWith(sigmoid_pattern, [x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) | |||
| target = CallWith(relu_pattern, inputs=[call_sigmoid]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | |||
| unregiste_pass(softmax_relu_pass) | |||
| assert "ReLU" in transformed_repr | |||
| assert "Sigmoid" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_isin_pattern_0(): | |||
| """ | |||
| Test the IsIn pattern, which expresses IsIn/OneOf semantics. | |||
| """ | |||
| @@ -81,16 +110,41 @@ def test_isin_pattern(): | |||
| target = CallWith(relu6_pattern, inputs=[x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(softmax_relu_pass) | |||
| unregiste_pass(softmax_relu_pass) | |||
| assert "ReLU6" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_isin_pattern_1(): | |||
| """ | |||
| Test IsIn. In this case, IsIn is used as a nested input of the target. | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_neg_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| call_softmax = CallWith(softmax_pattern, inputs=[x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU()) | |||
| call_relu = CallWith(relu_pattern, inputs=[x]) | |||
| pattern = IsIn([call_softmax, call_relu]) | |||
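| # IsIn matches when the node matches any of the listed alternatives; | |||
| # reusing it as the Neg input keeps the matched node in the target. | |||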
| neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False) | |||
| target = CallWith(neg_ops, inputs=[pattern]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_neg_pass) | |||
| assert "Neg" in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| def test_isnot_pattern_0(): | |||
| """ | |||
| Test the IsNot pattern, which expresses IsNot semantics. | |||
| Case: the IsNot pass fails to match | |||
| """ | |||
| set_renorm(False) | |||
| class ConvBN(nn.Cell): | |||
| def __init__(self): | |||
| super(ConvBN, self).__init__() | |||
| @@ -132,11 +186,11 @@ def test_isnot_pattern_0(): | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(single_bn_pass) | |||
| ppm.unregiste(bn_pass) | |||
| unregiste_pass(single_bn_pass) | |||
| unregiste_pass(bn_pass) | |||
| assert "ReLU6" not in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| set_renorm(True) | |||
| def test_isnot_pattern_1(): | |||
| """ | |||
| @@ -160,12 +214,15 @@ def test_isnot_pattern_1(): | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(single_bn_pass) | |||
| unregiste_pass(single_bn_pass) | |||
| assert "ReLU6" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_newtensor_pattern(): | |||
| """ | |||
| Test NewTensor pattern in the target | |||
| """ | |||
| set_renorm(False) | |||
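| # Renorm is switched off here: the target inserts an AddN with a | |||
| # brand-new weight tensor, which presumably would not pass renormalization. | |||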
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @@ -181,7 +238,84 @@ def test_newtensor_pattern(): | |||
| target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(softmax_addn_pass) | |||
| unregiste_pass(softmax_addn_pass) | |||
| assert "AddN" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| set_renorm(True) | |||
| def test_newparameter_pattern(): | |||
| """ | |||
| Test NewParameter pattern in the target | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para_0 = NewParameter("Merlin", default_tensor0) | |||
| new_para_1 = NewParameter("Arthur", default_tensor1) | |||
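| # Without gen_new_parameter, each match builds fresh "Merlin" and | |||
| # "Arthur" Parameters initialized from the default tensors. | |||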
| target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False) | |||
| target = CallWith("make_tuple", inputs=[target_0], should_replace=False) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_addn_pass) | |||
| assert "MatMul" in transformed_repr | |||
| assert "make_tuple" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_imm_pattern(): | |||
| """ | |||
| Test Imm pattern in the target | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| imm = Imm(0) | |||
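| # Imm wraps an immediate (compile-time integer) value; 0 selects the | |||
| # first element of the make_tuple below via tuple_getitem. | |||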
| target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False) | |||
| target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_addn_pass) | |||
| assert "make_tuple" in transformed_repr | |||
| assert "tuple_getitem" in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| def test_gen_new_parameter(): | |||
| """ | |||
| Test gen_new_parameter | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para = NewParameter("Merlin", default_tensor, should_replace=True) | |||
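| # With should_replace=True, gen_new_parameter registers a pipeline pass, | |||
| # so every later compilation shares this one "Merlin" Parameter until | |||
| # cancel_new_parameter(new_para) is called. | |||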
| gen_new_parameter(new_para) | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_make_tuple_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| assert "Merlin" in transformed_repr | |||
| unregiste_pass(softmax_make_tuple_pass) | |||
| cancel_new_parameter(new_para) | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| assert "Merlin" not in transformed_repr | |||