/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/func_graph.h"

#include <algorithm>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "debug/trace.h"
#include "ir/manager.h"
#include "operator/ops.h"
#include "pybind_api/export_flags.h"
#include "utils/ordered_set.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
/*
 * Methods of Graph
 */
FuncGraph::FuncGraph()
    : flags_(),
      transforms_(),
      parameter_default_value_(),
      seen_(0),
      parameters_(),
      has_vararg_(false),
      has_kwarg_(false),
      kwonlyargs_count_(0),
      hyper_param_count_(0),
      is_generated_(false),
      return_(nullptr),
      manager_(std::weak_ptr<FuncGraphManager>()) {
  debug_info_ = std::make_shared<GraphDebugInfo>();
}

// Returns the second input of the return node (the graph's output value), or
// nullptr while no return node with exactly two inputs has been set.
AnfNodePtr FuncGraph::output() const {
  // If return value is set, return should have two inputs.
  if (return_ != nullptr && return_->inputs().size() == 2) {
    return return_->input(1);
  }
  // If not set yet, return nullptr.
  return nullptr;
}

// Creates a fresh Parameter owned by this graph, appends it and returns it.
ParameterPtr FuncGraph::add_parameter() {
  FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
  ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
  add_parameter(p);
  return p;
}

// Appends `p` to the parameter list. When a manager is attached, the new list is
// installed through FuncGraphManager::SetParameters instead of being mutated in
// place (presumably so the manager can update its own bookkeeping).
void FuncGraph::add_parameter(const ParameterPtr &p) {
  // Lock once so the same manager instance is checked and used.
  auto mng = manager_.lock();
  if (mng != nullptr) {
    std::vector<AnfNodePtr> new_params = parameters_;
    new_params.push_back(p);
    mng->SetParameters(shared_from_base<FuncGraph>(), new_params);
  } else {
    parameters_.push_back(p);
  }
}

// Appends a named weight (hyper) parameter at the end of the parameter list and
// bumps hyper_param_count_; hyper parameters therefore occupy the list's tail.
ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
  FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  ParameterPtr p = std::make_shared<Parameter>(this_graph);
  p->set_name(name);
  p->debug_info()->set_name(name);

  // Same append path (with or without a manager) as a regular parameter.
  add_parameter(p);
  hyper_param_count_++;
  return p;
}

// Returns the stored value of `flag`, or false when the flag was never set.
bool FuncGraph::has_flag(const std::string &flag) {
  auto iter = flags_.find(flag);
  return iter != flags_.end() && iter->second;
}

// Creates a CNode owned by this graph. Graphs flagged GRAPH_FLAG_HAS_EFFECT also
// record the node in order_, preserving creation order.
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
  if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
    order_.push_back(cnode);
    MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order.";
  }
  return cnode;
}

// Same as NewCNode, then assigns `scope` to the new node.
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
  CNodePtr app = NewCNode(inputs);
  app->set_scope(scope);
  return app;
}

// Logs every CNode recorded in order_ at INFO level.
void FuncGraph::DumpCNodeList() {
  MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
  for (const auto &cnode : order_) {
    MS_LOG(INFO) << cnode->DebugString();
  }
}

// Human-readable label of this graph, derived from its debug info.
std::string FuncGraph::ToString() const {
  return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
}

// Returns the debug info, lazily binding it back to this graph on first access.
GraphDebugInfoPtr FuncGraph::debug_info() {
  MS_EXCEPTION_IF_NULL(this->debug_info_);
  if (this->debug_info_->get_graph() == nullptr) {
    this->debug_info_->set_graph(shared_from_base<FuncGraph>());
  }
  return this->debug_info_;
}

const AnfNodeSet &FuncGraph::nodes() { return nodes_; }

void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }

void FuncGraph::ClearNodes() { nodes_.clear(); }

void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); }

// Removes `node` from the node set and from its owning graph's order list.
void FuncGraph::DropNode(AnfNodePtr node) {
  nodes_.erase(node);
  auto graph = node->func_graph();
  // Remove the node from order list.
  if (graph) {
    graph->EraseUnusedNodeInOrder(node);
  }
}

const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }

// Merges the value-node use counts of `source` into this graph.
void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
  auto &others = source->value_nodes();
  for (const auto &entry : others) {
    AddValueNode(entry.first, entry.second);
  }
}

void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }

// Adds `count` uses of value node `node`.
void FuncGraph::AddValueNode(AnfNodePtr node, int count) {
  if (value_nodes_.count(node) == 0) {
    value_nodes_[node] = count;
  } else {
    value_nodes_[node] += count;
  }
}

// Drops one use of `node`; the entry is erased when its count reaches zero.
// The negative-count branch only fires if a non-positive count was stored.
void FuncGraph::DropValueNode(AnfNodePtr node) {
  if (value_nodes_.count(node) != 0) {
    if (value_nodes_[node] == 1) {
      (void)value_nodes_.erase(node);
    } else {
      value_nodes_[node]--;
      if (value_nodes_[node] < 0) {
        MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
                          << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
      }
    }
  }
}

const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }

// Merges the free variables of `source`, skipping nodes that belong to this graph.
void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
  auto &others = source->free_variables();
  for (const auto &entry : others) {
    if (entry.first->func_graph().get() != this) {
      (void)AddFreeVariable(entry.first, entry.second);
    }
  }
}

void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }

// Adds `count` uses of free variable `node`; returns true if it was newly inserted.
bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) {
  if (free_variables_.count(node) == 0) {
    free_variables_[node] = count;
    return true;
  } else {
    free_variables_[node] += count;
    return false;
  }
}

// Drops one use of `node`; returns true only when the entry was erased.
bool FuncGraph::DropFreeVariable(AnfNodePtr node) {
  if (free_variables_.count(node) != 0) {
    if (free_variables_[node] == 1) {
      (void)free_variables_.erase(node);
      return true;
    } else {
      free_variables_[node]--;
      if (free_variables_[node] < 0) {
        MS_LOG(EXCEPTION) << "Count of free variable '" << node
                          << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
      }
    }
  }
  return false;
}

// Manager-computed transitive free variables of this graph; requires a manager.
const BaseRefCounterMap &FuncGraph::free_variables_total() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  auto &fv_total = mng->free_variables_total();
  return fv_total[shared_from_base<FuncGraph>()];
}

// The AnfNode entries of free_variables_total().
std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
  std::vector<AnfNodePtr> nodes;
  const auto &fv_total = this->free_variables_total();
  for (const auto &p : fv_total) {
    const auto &key = p.first;
    if (utils::isa<AnfNodePtr>(key)) {
      nodes.push_back(utils::cast<AnfNodePtr>(key));
    }
  }
  return nodes;
}

// The FuncGraph entries of free_variables_total().
std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
  std::vector<FuncGraphPtr> func_graphs;
  const auto &fv_total = this->free_variables_total();
  for (const auto &p : fv_total) {
    const auto &key = p.first;
    if (utils::isa<FuncGraphPtr>(key)) {
      func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
    }
  }
  return func_graphs;
}

const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }

// Merges the used-graph counts of `source`; `source` itself is never recorded as used.
void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
  auto &others = source->func_graphs_used();
  for (const auto &entry : others) {
    (void)AddFuncGraphUsed(entry.first, entry.second);
  }
  func_graphs_used_.erase(source);
}

void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }

// Adds `count` uses of graph `fg`; returns true if it was newly inserted.
bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
  if (func_graphs_used_.count(fg) == 0) {
    func_graphs_used_[fg] = count;
    return true;
  } else {
    func_graphs_used_[fg] += count;
    return false;
  }
}

// Drops one use of `fg`; returns true only when the entry was erased.
bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
  if (func_graphs_used_.count(fg) != 0) {
    if (func_graphs_used_[fg] == 1) {
      (void)func_graphs_used_.erase(fg);
      return true;
    } else {
      func_graphs_used_[fg]--;
      if (func_graphs_used_[fg] < 0) {
        MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
                          << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
      }
    }
  }
  return false;
}

// Manager-computed transitive set of used graphs; requires a manager.
const FuncGraphSet &FuncGraph::func_graphs_used_total() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
  return used;
}

const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }

// Merges the (CNode, input index) user records of `source`.
void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
  auto &others = source->func_graph_cnodes_index();
  for (const auto &entry : others) {
    // Ignore the user graph who may own itself.
    auto fg = entry.first->first->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    if (fg.get() != this) {
      AddFuncGraphCNodeIndex(entry.first, entry.second);
    }
  }
}

void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }

// Adds `count` to the (CNode, index) user record `pair`.
void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) {
  if (func_graph_cnodes_index_.count(pair) == 0) {
    func_graph_cnodes_index_[pair] = count;
  } else {
    func_graph_cnodes_index_[pair] += count;
  }
}

// Drops one use of the (CNode, index) record `pair`; erased when it reaches zero.
void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
  if (func_graph_cnodes_index_.count(pair) != 0) {
    if (func_graph_cnodes_index_[pair] == 1) {
      (void)func_graph_cnodes_index_.erase(pair);
    } else {
      func_graph_cnodes_index_[pair]--;
      if (func_graph_cnodes_index_[pair] < 0) {
        MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
                          << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
      }
    }
  }
}

const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }

// Merges the J-graph use counts of `source`.
void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
  auto &others = source->j_func_graphs();
  for (const auto &entry : others) {
    AddJFuncGraph(entry.first, entry.second);
  }
}

void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }

// Adds `count` uses of J graph `fg`.
void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
  if (j_func_graphs_.count(fg) == 0) {
    j_func_graphs_[fg] = count;
  } else {
    j_func_graphs_[fg] += count;
  }
}

// Drops one use of J graph `fg`; erased when it reaches zero.
void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
  if (j_func_graphs_.count(fg) != 0) {
    if (j_func_graphs_[fg] == 1) {
      (void)j_func_graphs_.erase(fg);
    } else {
      j_func_graphs_[fg]--;
      if (j_func_graphs_[fg] < 0) {
        MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
                          << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
      }
    }
  }
}

// Parent graph as computed by the manager. Throws if no manager is attached.
FuncGraphPtr FuncGraph::parent() {
  // Report the bug early.
  auto mng = manager_.lock();
  if (mng == nullptr) {
    MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
                      << " NodeInfo: " << trace::GetDebugInfo(debug_info());
  }
  return mng->parent(shared_from_base<FuncGraph>());
}

// Child graphs as computed by the manager; requires a manager.
const FuncGraphSet &FuncGraph::children() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  return mng->children(shared_from_base<FuncGraph>());
}

// Scope graphs as computed by the manager; requires a manager.
const FuncGraphSet &FuncGraph::scope() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  return mng->scopes(shared_from_base<FuncGraph>());
}

// Whether the manager reports this graph as recursive; requires a manager.
bool FuncGraph::recursive() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  return mng->recursive(shared_from_base<FuncGraph>());
}

// Graphs involved in this graph's recursion, per the manager; requires a manager.
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
  auto mng = manager_.lock();
  MS_EXCEPTION_IF_NULL(mng);
  return mng->recursive_graphs(shared_from_base<FuncGraph>());
}

// Default value registered for parameter `name`. Returns nullptr when none is
// registered or when it is a NullObj placeholder; throws on a stored nullptr.
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
  auto itr = this->parameter_default_value_.find(name);
  if (itr == parameter_default_value_.end()) {
    return nullptr;
  }
  const auto &default_value = itr->second;
  if (default_value == nullptr) {
    MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
  }
  if (IsValueNode<NullObj>(default_value)) {
    return nullptr;
  }
  return default_value;
}

// Set the default values: name_list[i] gets value_list[i]. Nothing is recorded
// when value_list is empty or consists only of NullObj placeholders.
// Throws if value_list has fewer entries than name_list.
void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
  // std::all_of is true for an empty range, so this also covers value_list.empty().
  bool all_is_null = std::all_of(value_list.begin(), value_list.end(),
                                 [](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); });
  if (all_is_null) {
    return;
  }
  // Guard the indexed access below; a shorter value_list would read out of bounds.
  if (value_list.size() < name_list.size()) {
    MS_LOG(EXCEPTION) << "Default value list size " << value_list.size() << " is less than parameter name list size "
                      << name_list.size() << ".";
  }
  for (size_t i = 0; i < name_list.size(); ++i) {
    this->parameter_default_value_[name_list[i]] = value_list[i];
  }
}

void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }

// Number of registered defaults that are not NullObj placeholders.
size_t FuncGraph::GetDefaultValueCount() {
  // `const auto &` binds to the map's real value_type (const key) without copying.
  auto null_count = std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
                                  [](const auto &pair) { return IsValueNode<NullObj>(pair.second); });
  return parameter_default_value_.size() - static_cast<size_t>(null_count);
}

// Parameter layout used below: hyper (weight) parameters occupy the last
// hyper_param_count_ slots, **kwargs (if any) sits right before them, and
// *args (if any) right before **kwargs.

// The *args parameter, or nullptr if the graph has none.
AnfNodePtr FuncGraph::GetVariableArgParameter() {
  if (!has_vararg_) {
    return nullptr;
  }

  if (has_kwarg_) {
    if (parameters_.size() < hyper_param_count_ + 2) {
      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
                        << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
    }
    return parameters_[parameters_.size() - hyper_param_count_ - 2];
  }

  if (parameters_.size() < hyper_param_count_ + 1) {
    MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
                      << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  }
  return parameters_[parameters_.size() - hyper_param_count_ - 1];
}

// Name of the *args parameter, or "" if the graph has none.
std::string FuncGraph::GetVariableArgName() {
  if (!has_vararg_) {
    return "";
  }
  // Same lookup and bounds checks as GetVariableArgParameter.
  auto node = GetVariableArgParameter();
  MS_EXCEPTION_IF_NULL(node);
  auto param = node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  return param->name();
}

// The **kwargs parameter, or nullptr if the graph has none.
AnfNodePtr FuncGraph::GetVariableKwargParameter() {
  if (has_kwarg_) {
    if (parameters_.size() < hyper_param_count_ + 1) {
      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
                        << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
    }
    return parameters_[parameters_.size() - hyper_param_count_ - 1];
  }
  return nullptr;
}

// Name of the **kwargs parameter, or "" if the graph has none.
std::string FuncGraph::GetVariableKwargName() {
  if (!has_kwarg_) {
    return "";
  }
  // Same lookup and bounds checks as GetVariableKwargParameter.
  auto node = GetVariableKwargParameter();
  MS_EXCEPTION_IF_NULL(node);
  auto param = node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  return param->name();
}

// Parameters excluding *args, **kwargs, keyword-only and hyper parameters.
int FuncGraph::GetPositionalArgsCount() const {
  int count = SizeToInt(parameters_.size());
  if (has_kwarg_) {
    count--;
  }
  if (has_vararg_) {
    count--;
  }
  return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
}

// First parameter whose name equals `name`, or nullptr when none matches.
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
  for (size_t i = 0; i < parameters_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(parameters_[i]);
    auto param_cast = parameters_[i]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_cast);
    if (param_cast->name() == name) {
      return parameters_[i];
    }
  }
  return nullptr;
}

void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }

// CNodes of this graph in execution order. For GRAPH_FLAG_HAS_EFFECT graphs
// this is the recorded order_; otherwise a topological sort from the return
// node restricted to this graph (free variables included as successors).
std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
  if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
    MS_LOG(DEBUG) << "Return ordered cnodes.";
    return order_;
  }

  auto this_ptr = shared_from_base<FuncGraph>();
  auto belong_same_graph = [this_ptr](const AnfNodePtr &node) { return IncludeBelongGraph(this_ptr, node); };
  auto succ_depends = [this_ptr](const AnfNodePtr &node) { return SuccIncludeFV(this_ptr, node); };

  std::list<CNodePtr> cnodes;
  auto sorted_nodes = TopoSort(get_return(), succ_depends, belong_same_graph);
  for (const auto &node : sorted_nodes) {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode) {
      cnodes.push_back(cnode);
    }
  }
  return cnodes;
}

// For GRAPH_FLAG_HAS_EFFECT graphs with a live manager, removes from order_
// every CNode that is no longer in this graph's node set.
void FuncGraph::EraseUnusedNodeInOrder() {
  if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
    auto mng = manager_.lock();
    if (mng) {
      auto &all_nodes = nodes();
      // Erase unused cnode.
      for (auto it = order_.begin(); it != order_.end();) {
        if (all_nodes.count(*it)) {
          ++it;
        } else {
          MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
          it = order_.erase(it);
        }
      }
    }
  }
}

// Removes CNode `n` from order_ (GRAPH_FLAG_HAS_EFFECT graphs only).
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
  if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
    order_.remove(n->cast<CNodePtr>());
    MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
  }
}

// Validates order_ for GRAPH_FLAG_HAS_EFFECT graphs: every input CNode owned by
// this graph must precede its user, and (with a manager) order_ plus the
// parameters must account for every managed node. Throws on violation.
void FuncGraph::CheckOrder() {
  if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
    MS_LOG(DEBUG) << "Check graph " << ToString();
    // Loop-invariant: computed once instead of per input.
    auto this_graph = shared_from_base<FuncGraph>();
    for (auto it = order_.begin(); it != order_.end(); ++it) {
      for (const auto &input_node : (*it)->inputs()) {
        if (input_node && input_node->isa<CNode>() && input_node->func_graph() == this_graph) {
          // An input CNode of this graph must appear earlier in order_.
          auto found = std::find(order_.begin(), it, input_node);
          if (found == it) {
            DumpCNodeList();
            MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString()
                              << " doesn't obey the input dependency, "
                              << "as input " << input_node->DebugString() << " is not ahead of itself.";
          }
        }
      }
    }
    auto mng = manager_.lock();
    if (mng != nullptr) {
      const auto &all_nodes = nodes();
      if (all_nodes.size() != (order_.size() + parameters_.size())) {
        DumpCNodeList();
        MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
                          << all_nodes.size() - parameters_.size() << ".";
      }
    }
    MS_LOG(DEBUG) << "Check order okay.";
  }
}

// Returns a new, strictly increasing generation number (process-wide counter).
size_t NewFgSeenGeneration() {
  static size_t fg_seen_generation = 0;
  return ++fg_seen_generation;
}

const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
}  // namespace mindspore