/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): in this copy of the file the targets of seven system
// `#include` directives and every template argument list (on utils::isa,
// utils::cast, std::make_shared, ->cast(), ->isa(), IsValueNode, std::vector,
// std::function) appear to have been stripped, so the code does not compile
// as shown. They are kept verbatim below; restore them from version control
// rather than guessing. Comments that name the elided types are marked as
// presumptions.
#include "pre_activate/common/optimizer.h"
#include
#include
#include
#include
#include
#include
#include
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "ir/manager.h"

namespace mindspore {
namespace opt {
namespace {
// Forward declaration: HandleSexpVector and SexpToNode are mutually recursive.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph);

// Wraps a leaf `sexp` in a ValueNode when it holds one of the four value kinds
// probed below (types elided in this copy). Returns nullptr when none match,
// leaving error reporting to the caller (SexpToNode).
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  if (utils::isa(sexp)) {
    return NewValueNode(utils::cast(sexp));
  }
  if (utils::isa(sexp)) {
    return NewValueNode(utils::cast(sexp));
  }
  if (utils::isa(sexp)) {
    return NewValueNode(utils::cast(sexp));
  }
  if (utils::isa(sexp)) {
    return NewValueNode(utils::cast(sexp));
  }
  return nullptr;
}

// Builds a CNode over `input_nodes`, owned by `graph`. `graph` is tested
// against two graph representations (presumably a concrete FuncGraph or a
// pattern Var; TODO confirm against the elided template arguments).
// Returns nullptr when `graph` matches neither.
CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) {
  if (utils::isa(graph)) {
    return std::make_shared(input_nodes, utils::cast(graph));
  }
  if (utils::isa(graph)) {
    return std::make_shared(input_nodes, utils::cast(graph));
  }
  return nullptr;
}

// Turns the pattern variable `sexp` into a VarNode.
// - When `graph` is itself a Var (per the "make VarPtr" log), the VarNode
//   is created with no owning graph.
// - When `graph` is a graph pointer (per the "GraphPtr" log), the VarNode
//   is attached to that graph.
// - Any other `graph` is logged as an error and nullptr is returned.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared(utils::cast(sexp), nullptr);
  }
  if (utils::isa(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared(utils::cast(sexp), utils::cast(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}

// Converts a pattern s-expression into an ANF node tree rooted in `graph`.
// Dispatch order:
//   1. sequence          -> HandleSexpVector (recurses into the elements);
//   2. Var               -> if it carries a primitive, records it in
//                           *primitive_vars and returns a ValueNode of that
//                           primitive; otherwise a VarNode;
//   3. node-like ref     -> returned unchanged;
//   4. anything else     -> leaf ValueNode via CreateValueNodeWithSexp.
// `primitive_vars` must be non-null.
// Raises via MS_LOG(EXCEPTION) if the leaf cannot be converted.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                      bool multigraph = false) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa(sexp)) {
    auto var_ptr = utils::cast(sexp);
    MS_EXCEPTION_IF_NULL(var_ptr);
    if (var_ptr->primitive()) {
      // Remember which Var stands for this primitive so the pattern engine can bind it.
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa(sexp)) {
    return utils::cast(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp);
  if (value_node == nullptr) {
    MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  }
  return value_node;
}

// Converts every element of the sequence `sexp` with SexpToNode and
// wraps the results in a CNode.
// In multigraph mode, when `graph` is a Var:
// - each element is converted under a fresh Var named "G", with multigraph
//   kept on;
// - the CNode is attached directly to the `graph` Var.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector input_nodes;
  const auto &tuple = utils::cast(sexp);
  if (multigraph && utils::isa(graph)) {
    for (auto &x : tuple) {
      AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true);
      input_nodes.push_back(node);
    }
    VarPtr var_ptr = utils::cast(graph);
    return std::make_shared(input_nodes, var_ptr);
  }

  for (auto &x : tuple) {
    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
    input_nodes.push_back(node);
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}
}  // namespace

// Equality predicate handed to the PatternEngine.
// - If both refs are primitive value nodes, they are equal when the
//   primitives' names match.
// - Else, if both are value nodes, the pointed-to values are compared.
//   This raises via MS_LOG(EXCEPTION) if a cast or value is null.
// - Otherwise it falls back to BaseRef operator==. The graph branch below
//   only logs.
static bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa(a) && utils::isa(b)) {
    auto a_node = utils::cast(a);
    auto b_node = utils::cast(b);
    if (IsValueNode(a_node) && IsValueNode(b_node)) {
      auto a_value_node = a_node->cast();
      auto a_value = a_value_node->value();
      auto a_prim = a_value->cast();

      auto b_value_node = b_node->cast();
      auto b_value = b_value_node->value();
      auto b_prim = b_value->cast();

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa() && b_node->isa()) {
      auto a_value_node_ptr = a_node->cast();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }
      auto b_value_node_ptr = b_node->cast();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }
      return (*a_value_ptr) == (*b_value_ptr);
    }
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa(a) && utils::isa(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  return a == b;
}

// Type-equality predicate handed to the PatternEngine.
// Two refs that both pass the (elided) utils::isa check are always treated
// as type-equal. Per the original note, this lets a CNode in the pattern
// match a kernel's CNode. Otherwise BaseRef::type() is compared.
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  // To match CNode and Kernel's type
  if (utils::isa(a) && utils::isa(b)) {
    return true;
  }
  return a.type() == b.type();
}

// Creates a pattern-rewriting pass. The engine matches with AnfEqual and
// CNodeTypeEqual. The pattern itself is built lazily on the first Run().
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
    : NodePass(name),
      multigraph_(multigraph),
      pattern_engine_(PatternEngine(std::make_shared(), std::function(AnfEqual), std::function(CNodeTypeEqual))),
      primitive_vars_(std::make_shared()) {}

// Default pattern: a one-element list holding a fresh Var. Derived passes
// presumably override this to describe what they match.
const BaseRef PatternProcessPass::DefinePattern() const {
  VarPtr X = std::make_shared();
  return BaseRef({X});
}

// Converts DefinePattern() into an ANF node tree under a root Var "RootG".
// The result is cached in pattern_, and primitive_vars_ is filled along the way.
void PatternProcessPass::Build() {
  VarPtr fg = std::make_shared("RootG");
  // No std::move: DefinePattern() returns `const BaseRef` by value, so
  // std::move would yield `const BaseRef&&`. That forces a copy and blocks
  // elision of the returned temporary.
  BaseRef pattern = DefinePattern();
  pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
}

// Matches the (lazily built) pattern against `node`. On a non-empty match,
// returns whatever Process() produces; otherwise returns nullptr.
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (pattern_ == nullptr) {
    Build();
  }

  auto empty_equiv = std::make_shared();
  MS_EXCEPTION_IF_NULL(primitive_vars_);
  EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
  if (equiv != nullptr && !equiv->empty()) {
    return Process(func_graph, node, equiv);
  }
  return nullptr;
}

// Appends `pass_manager` to the pipeline; null pointers are silently ignored.
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
  if (pass_manager != nullptr) {
    pass_managers_.push_back(pass_manager);
  }
}

// Runs every registered pass manager over `func_graph`, in registration order.
// Repeats the sweep until no pass manager reports a change. It stops after one
// sweep if run_only_once is set, or if exactly one pass manager is registered.
// If the graph has no manager, one is created and attached. Afterwards the
// manager keeps only `func_graph` as a root. Returns `func_graph`.
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
  MS_EXCEPTION_IF_NULL(func_graph);
  run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, false);
    func_graph->set_manager(manager);
  }

  bool changed = true;
  while (changed) {
    changed = false;
    for (size_t i = 0; i < pass_managers_.size(); ++i) {
      const PassManagerPtr &pm = pass_managers_[i];
      if (pm != nullptr && pm->Run(func_graph)) {
        changed = true;
      }
    }
    if (run_only_once_) {
      break;
    }
  }

  std::vector func_graphs;
  func_graphs.push_back(func_graph);
  manager->KeepRoots(func_graphs);
  // NOTE(review): the sort result is discarded; presumably called for a side effect — TODO confirm.
  (void)TopoSort(func_graph->get_return());
  return func_graph;
}
}  // namespace opt
}  // namespace mindspore