/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "optimizer/opt.h" #include #include #include #include #include "ir/anf.h" #include "ir/manager.h" #include "utils/ordered_set.h" #include "utils/log_adapter.h" #include "optimizer/optimizer.h" namespace mindspore { /* namespace to support opt */ namespace opt { SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &renorm_action) { auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const std::vector &prims, const RenormAction &renorm_action) { auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { return false; } for (auto &prim : prims) { if (IsPrimitiveCNode(node, prim)) { return true; } } return false; }; return std::make_shared(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { #ifdef ENABLE_PROFILE double t = GetTime(); #endif AnfNodePtr result = transform_(optimizer, node); #ifdef ENABLE_PROFILE if (optimizer != nullptr) { auto time = GetTime(); MsProfile::StatTime("substitution." + name_, time - t); if (result != nullptr) { MsProfile::StatTime("match." + name_, time - t); } } #endif if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { if (renorm_action_ == FORCE_RENORM) { optimizer->add_node_to_renormalize(result); } else { // renorm_action_ is CHECK_RENORM if (result->abstract() == nullptr) { optimizer->add_node_to_renormalize(result); } } } return result; } bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, const SubstitutionPtr &transform) const { FuncGraphManagerPtr manager = optimizer->manager(); std::unordered_set seen_node; std::deque todo{root_node}; bool changes = false; while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); // check whether this node has been matched. if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) { continue; } (void)seen_node.insert(node); // select nodes that this transform can be applied. bool is_match = transform->predicate_(node); // apply transform on this node bool change = false; if (is_match) { auto ret = (*transform)(optimizer, node); if (ret != nullptr && ret != node) { change = true; #ifdef ENABLE_PROFILE double t = GetTime(); #endif (void)manager->Replace(node, ret); #ifdef ENABLE_PROFILE MsProfile::StatTime("replace." + transform->name_, GetTime() - t); #endif node = ret; } } // find success, and add them to todo list if (IsValueNode(node)) { todo.push_back(GetValueNode(node)->output()); } if (node->isa()) { auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { for (auto &use : node_users[node]) { auto use_node = use.first; todo.push_back(use_node); if (seen_node.find(use_node) != seen_node.end()) { (void)seen_node.erase(use_node); } } } changes = changes || change; } return changes; } bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = optimizer->manager(); manager->AddFuncGraph(func_graph); bool loop = false; bool changes = false; do { loop = false; for (auto const &transform : list_) { auto change = ApplyTransform(optimizer, func_graph->output(), transform); changes = changes || change; loop = loop || change; } if (is_once_) { break; } } while (loop); return changes; } } // namespace opt } // namespace mindspore