/** * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "frontend/optimizer/opt.h" #include #include #include #include #include "ir/anf.h" #include "ir/manager.h" #include "frontend/optimizer/optimizer.h" #include "utils/log_adapter.h" namespace mindspore { /* namespace to support opt */ namespace opt { SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &renorm_action) { auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &renorm_action) { auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { return false; } auto cnode = node->cast(); auto inp0 = cnode->input(0); auto prim0 = GetValueNode(inp0); if (prim0 == nullptr) { return false; } auto hash = prim0->Hash(); auto const &name = prim0->name(); for (auto &prim : prims) { if (hash == prim->Hash() && name == prim->name()) { return true; } } return false; }; return std::make_shared(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double t = GetTime(); #endif AnfNodePtr result = (*transform_)(optimizer, node); #ifdef ENABLE_PROFILE if (optimizer != nullptr) { auto time = GetTime(); MsProfile::StatTime("substitution." + name_, time - t); if (result != nullptr) { MsProfile::StatTime("match." + name_, time - t); } } #endif if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) { optimizer->set_is_untyped_generated(); } } return result; } static bool isTraversable(const AnfNodePtr &node) { if (node == nullptr) { return false; } if (node->isa() || node->isa()) { return true; } if (IsValueNode(node) || IsValueNode(node)) { return true; } return false; } static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &substitution) { auto manager = optimizer->manager(); bool is_match = substitution->predicate_(node); if (is_match) { TraceGuard trace_guard(std::make_shared(node->debug_info())); auto res = (*substitution)(optimizer, node); if (res != nullptr && res != node) { #ifdef ENABLE_PROFILE double t = GetTime(); #endif MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by " << substitution->name_; (void)manager->Replace(node, res); #ifdef ENABLE_PROFILE MsProfile::StatTime("replace." + substitution->name_, GetTime() - t); #endif return res; } } return nullptr; } static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque *todo, bool change, size_t seen) { if (IsValueNode(node)) { (*todo).emplace_back(GetValueNode(node)->output()); } if (node->isa()) { auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); } if (!change) { return; } auto manager = optimizer->manager(); auto &node_users = manager->node_users(); auto users_iterator = node_users.find(node); if (users_iterator == node_users.end()) { return; } auto users = users_iterator->second; for (auto &use : users) { auto use_node = use.first; if (use_node == nullptr) { continue; } (*todo).emplace_back(use_node); if (use_node->seen_ == seen) { use_node->seen_--; } } } bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { #ifdef ENABLE_PROFILE double start = GetTime(); #endif FuncGraphManagerPtr manager = optimizer->manager(); auto seen = NewSeenGeneration(); // 1024 is for the initial capacity of deque std::deque todo(1024); todo.clear(); todo.emplace_back(func_graph->output()); bool changes = false; auto &all_nodes = manager->all_nodes(); while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; } node->seen_ = seen; bool change = false; for (auto &substitution : list_) { auto res = DoTransform(optimizer, node, substitution); if (res != nullptr) { change = true; changes = true; node = res; todo.emplace_back(res); break; } } UpdateTransformingList(optimizer, node, &todo, change, seen); } #ifdef ENABLE_PROFILE MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start); #endif return changes; } bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, const SubstitutionPtr &substitution) const { #ifdef ENABLE_PROFILE double start = GetTime(); #endif FuncGraphManagerPtr manager = optimizer->manager(); auto seen = NewSeenGeneration(); // 1024 is for the initial capacity of deque std::deque todo(1024); todo.clear(); todo.emplace_back(root_node); bool changes = false; auto &all_nodes = manager->all_nodes(); while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; } node->seen_ = seen; bool change = false; auto res = DoTransform(optimizer, node, substitution); if (res != nullptr) { change = true; changes = true; node = res; } UpdateTransformingList(optimizer, node, &todo, change, seen); } #ifdef ENABLE_PROFILE MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); #endif return changes; } bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { // Add for substitution status counting size_t space = 0; std::unordered_map> status; if (optimizer->is_on_debug_) { for (size_t i = 0; i < list_.size(); i++) { status[list_[i]->name_ + std::to_string(i)] = {}; } } bool changes = false; bool loop = true; while (loop) { loop = false; for (size_t i = 0; i < list_.size(); i++) { const auto &substitution = list_[i]; bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution); changes = changes || change; loop = loop || change; static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); if (enable_dump_pass_ir && MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" + optimizer->CurPass_.name + "_" + substitution->name_; DumpIR(fg_name + ".ir", func_graph); if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { func_graph->DumpFuncGraph(fg_name); ExportIR(fg_name + ".dat", "", func_graph); } } // Record the status of each substitution if (optimizer->is_on_debug_) { status[substitution->name_ + std::to_string(i)].push_back(change); space = std::max(substitution->name_.size(), space); } } if (is_once_) { break; } } // Display the status of each substitution if (optimizer->is_on_debug_) { std::stringstream ss; ss << std::endl << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name << std::endl; for (size_t i = 0; i < list_.size(); i++) { auto name = list_[i]->name_; ss << std::left << std::setw(space + 4) << name << "\t"; for (auto change : status[name + std::to_string(i)]) { ss << change << " "; } ss << std::endl; } MS_LOG(DEBUG) << ss.str(); } return changes; } bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = optimizer->manager(); manager->AddFuncGraph(func_graph); bool changes = false; static const auto traverse_mode = (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions : kOptTraverseFromSubstitutionsToIR); if (traverse_mode == kOptTraverseFromIRToSubstitutions && MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) { MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name; changes = ApplyIRToSubstitutions(optimizer, func_graph); } else { MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name; changes = ApplySubstitutionsToIR(optimizer, func_graph); } return changes; } } // namespace opt } // namespace mindspore