| @@ -29,6 +29,7 @@ | |||
| #include "utils/visible.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/ordered_map.h" | |||
| namespace mindspore { | |||
| template <typename T> | |||
| @@ -195,25 +195,88 @@ GraphDebugInfoPtr FuncGraph::debug_info() { | |||
| return this->debug_info_; | |||
| } | |||
| const AnfNodeSet &FuncGraph::nodes() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &nodes = mng->nodes(); | |||
| return nodes[shared_from_base<FuncGraph>()]; | |||
| const AnfNodeSet &FuncGraph::nodes() { return nodes_; } | |||
| void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; } | |||
| void FuncGraph::ClearNodes() { nodes_.clear(); } | |||
| void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } | |||
| void FuncGraph::DropNode(AnfNodePtr node) { | |||
| nodes_.erase(node); | |||
| auto graph = node->func_graph(); | |||
| // Remove the node from order list. | |||
| if (graph) { | |||
| graph->EraseUnusedNodeInOrder(node); | |||
| } | |||
| } | |||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &cts = mng->valuenodes(); | |||
| return cts[shared_from_base<FuncGraph>()]; | |||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } | |||
| void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; } | |||
| void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } | |||
| void FuncGraph::AddValueNode(AnfNodePtr node, int count) { | |||
| if (value_nodes_.count(node) == 0) { | |||
| value_nodes_[node] = count; | |||
| } else { | |||
| value_nodes_[node] += count; | |||
| } | |||
| } | |||
| const AnfNodeCounterMap &FuncGraph::free_variables_direct() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &fv_direct = mng->free_variables_direct(); | |||
| return fv_direct[shared_from_base<FuncGraph>()]; | |||
| void FuncGraph::DropValueNode(AnfNodePtr node) { | |||
| if (value_nodes_.count(node) != 0) { | |||
| if (value_nodes_[node] == 1) { | |||
| (void)value_nodes_.erase(node); | |||
| } else { | |||
| value_nodes_[node]--; | |||
| if (value_nodes_[node] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of ValueNode '" << node | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } | |||
| void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) { | |||
| auto it = others.begin(); | |||
| for (; it != others.end(); it++) { | |||
| if (it->first->func_graph().get() != this) { | |||
| (void)AddFreeVariable(it->first, it->second); | |||
| } | |||
| } | |||
| } | |||
| void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } | |||
| bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { | |||
| if (free_variables_.count(node) == 0) { | |||
| free_variables_[node] = count; | |||
| return true; | |||
| } else { | |||
| free_variables_[node] += count; | |||
| return false; | |||
| } | |||
| } | |||
| bool FuncGraph::DropFreeVariable(AnfNodePtr node) { | |||
| if (free_variables_.count(node) != 0) { | |||
| if (free_variables_[node] == 1) { | |||
| (void)free_variables_.erase(node); | |||
| return true; | |||
| } else { | |||
| free_variables_[node]--; | |||
| if (free_variables_[node] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of free variable '" << node | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| const BaseRefCounterMap &FuncGraph::free_variables_total() { | |||
| @@ -249,11 +312,36 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | |||
| return func_graphs; | |||
| } | |||
| const FuncGraphCounterMap &FuncGraph::func_graphs_used() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &used = mng->func_graphs_used(); | |||
| return used[shared_from_base<FuncGraph>()]; | |||
| const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; } | |||
| void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; } | |||
| void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); } | |||
| bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) { | |||
| if (func_graph_value_nodes_.count(node) == 0) { | |||
| func_graph_value_nodes_[node] = count; | |||
| return true; | |||
| } else { | |||
| func_graph_value_nodes_[node] += count; | |||
| return false; | |||
| } | |||
| } | |||
| bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) { | |||
| if (func_graph_value_nodes_.count(node) != 0) { | |||
| if (func_graph_value_nodes_[node] == 1) { | |||
| (void)func_graph_value_nodes_.erase(node); | |||
| return true; | |||
| } else { | |||
| func_graph_value_nodes_[node]--; | |||
| if (func_graph_value_nodes_[node] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||
| @@ -263,15 +351,68 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||
| return used; | |||
| } | |||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { | |||
| auto mng = manager_.lock(); | |||
| if (mng == nullptr) { | |||
| MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } | |||
| void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) { | |||
| auto it = others.begin(); | |||
| for (; it != others.end(); it++) { | |||
| // Ignore the user graph who may own itself. | |||
| if (it->first->first->func_graph().get() != this) { | |||
| AddFuncGraphCNodeIndex(it->first, it->second); | |||
| } | |||
| } | |||
| } | |||
| void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } | |||
| void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { | |||
| if (func_graph_cnodes_index_.count(pair) == 0) { | |||
| func_graph_cnodes_index_[pair] = count; | |||
| } else { | |||
| func_graph_cnodes_index_[pair] += count; | |||
| } | |||
| } | |||
| void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { | |||
| if (func_graph_cnodes_index_.count(pair) != 0) { | |||
| if (func_graph_cnodes_index_[pair] == 1) { | |||
| (void)func_graph_cnodes_index_.erase(pair); | |||
| } else { | |||
| func_graph_cnodes_index_[pair]--; | |||
| if (func_graph_cnodes_index_[pair] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; } | |||
| void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; } | |||
| void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); } | |||
| void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) { | |||
| if (j_func_graph_value_nodes_.count(node) == 0) { | |||
| j_func_graph_value_nodes_[node] = count; | |||
| } else { | |||
| j_func_graph_value_nodes_[node] += count; | |||
| } | |||
| } | |||
| void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) { | |||
| if (j_func_graph_value_nodes_.count(node) != 0) { | |||
| if (j_func_graph_value_nodes_[node] == 1) { | |||
| (void)j_func_graph_value_nodes_.erase(node); | |||
| } else { | |||
| j_func_graph_value_nodes_[node]--; | |||
| if (j_func_graph_value_nodes_[node] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &cnode = mng->func_graph_cnodes_index(); | |||
| return cnode[shared_from_base<FuncGraph>()]; | |||
| } | |||
| FuncGraphPtr FuncGraph::parent() { | |||
| @@ -662,10 +803,10 @@ void FuncGraph::EraseUnusedNodeInOrder() { | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| auto mng = manager_.lock(); | |||
| if (mng) { | |||
| auto nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||
| auto &all_nodes = nodes(); | |||
| // Erase unused cnode. | |||
| for (auto it = order_.begin(); it != order_.end();) { | |||
| if (nodes.count(*it)) { | |||
| if (all_nodes.count(*it)) { | |||
| (void)it++; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; | |||
| @@ -702,11 +843,11 @@ void FuncGraph::CheckOrder() { | |||
| } | |||
| auto mng = manager_.lock(); | |||
| if (mng != nullptr) { | |||
| const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||
| if (nodes.size() != (order_.size() + parameters_.size())) { | |||
| const auto &all_nodes = nodes(); | |||
| if (all_nodes.size() != (order_.size() + parameters_.size())) { | |||
| DumpCNodeList(); | |||
| MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | |||
| << nodes.size() - parameters_.size() << "."; | |||
| << all_nodes.size() - parameters_.size() << "."; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Check order okay."; | |||
| @@ -26,6 +26,7 @@ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <functional> | |||
| #include "ir/anf.h" | |||
| #include "ir/manager.h" | |||
| @@ -36,8 +37,13 @@ | |||
| namespace mindspore { | |||
| using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | |||
| using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | |||
| using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; | |||
| using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; | |||
| template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>> | |||
| using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>; | |||
| using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>; | |||
| using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>; | |||
| using FuncGraphMap = OrderedMap<FuncGraphPtr, int>; | |||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | |||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | |||
| @@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase { | |||
| // get all nodes belonging to this func graph | |||
| const AnfNodeSet &nodes(); | |||
| void CopyNodes(const AnfNodeSet &other_nodes); | |||
| void ClearNodes(); | |||
| void AddNode(AnfNodePtr node); | |||
| void DropNode(AnfNodePtr node); | |||
| // get all value_nodes belonging to this func graph | |||
| const AnfNodeCounterMap &value_nodes(); | |||
| // get all vars directly pointed to in this func graph | |||
| const AnfNodeCounterMap &free_variables_direct(); | |||
| void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes); | |||
| void ClearValueNodes(); | |||
| void AddValueNode(AnfNodePtr node, int count = 1); | |||
| void DropValueNode(AnfNodePtr node); | |||
| // get all free vars directly used in this func graph | |||
| const AnfNodeCounterMap &free_variables(); | |||
| void CopyFreeVariables(const AnfNodeCounterMap &others); | |||
| void ClearFreeVariables(); | |||
| bool AddFreeVariable(AnfNodePtr node, int count = 1); | |||
| bool DropFreeVariable(AnfNodePtr node); | |||
| // get all vars required by this func graph | |||
| const BaseRefCounterMap &free_variables_total(); | |||
| @@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase { | |||
| // get all vars that are func graphs | |||
| std::vector<FuncGraphPtr> free_variables_func_graphs(); | |||
| // get all func graphs directly used by this func graph | |||
| const FuncGraphCounterMap &func_graphs_used(); | |||
| // get all value nodes of func graph directly used by this func graph | |||
| const AnfNodeCounterMap &func_graph_value_nodes(); | |||
| void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others); | |||
| void ClearFuncGraphValueNodes(); | |||
| bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1); | |||
| bool DropFuncGraphValueNode(AnfNodePtr node); | |||
| // get all value nodes of J func graph directly used by this func graph | |||
| const AnfNodeCounterMap &j_func_graph_value_nodes(); | |||
| void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others); | |||
| void ClearJFuncGraphValueNodes(); | |||
| void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1); | |||
| void DropJFuncGraphValueNode(AnfNodePtr node); | |||
| // get all func graphs nested used by this func graph | |||
| const FuncGraphSet &func_graphs_used_total(); | |||
| // get all user value nodes of this func graph | |||
| // get all user value nodes of this func graph, by CNode and its input's index | |||
| const CNodeIndexCounterMap &func_graph_cnodes_index(); | |||
| void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes); | |||
| void ClearFuncGraphCNodesIndex(); | |||
| void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); | |||
| void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); | |||
| // Return the parent of this graph. | |||
| FuncGraphPtr parent(); | |||
| @@ -270,6 +303,24 @@ class FuncGraph : public FuncGraphBase { | |||
| // graph is manipulated by manager and others | |||
| friend FuncGraphManager; | |||
| // all nodes of the function | |||
| AnfNodeSet nodes_; | |||
| // all value nodes of the function | |||
| AnfNodeCounterMap value_nodes_; | |||
| // all func graph value nodes of the function | |||
| AnfNodeCounterMap func_graph_value_nodes_; | |||
| // all free variables of the function | |||
| AnfNodeCounterMap free_variables_; | |||
| // all value nodes calling J in the function | |||
| AnfNodeCounterMap j_func_graph_value_nodes_; | |||
| // all user value nodes of this func graph, recording by CNode and its input's index | |||
| CNodeIndexCounterMap func_graph_cnodes_index_; | |||
| // parameters of this function | |||
| std::vector<AnfNodePtr> parameters_; | |||
| std::vector<AnfNodePtr> paramter_obj_nodes_; | |||
| @@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { | |||
| if (!clone_all_valuenodes_) { | |||
| return; | |||
| } | |||
| auto &value_nodes = manager_->valuenodes()[func_graph]; | |||
| auto &value_nodes = func_graph->value_nodes(); | |||
| for (auto &value_node : value_nodes) { | |||
| auto old_node = value_node.first; | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| @@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { | |||
| if (!clone_all_used_graphs_) { | |||
| return; | |||
| } | |||
| auto &used_graphs = manager_->func_graphs_used()[func_graph]; | |||
| for (auto &used_graph : used_graphs) { | |||
| todo_.push_back({used_graph.first, nullptr, {}}); | |||
| auto &used = func_graph->func_graph_value_nodes(); | |||
| for (auto &fg_value_node : used) { | |||
| todo_.push_back({GetValueNode<FuncGraphPtr>(fg_value_node.first), nullptr, {}}); | |||
| } | |||
| } | |||
| @@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func | |||
| } | |||
| target_func_graph->set_return(return_node); | |||
| auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; | |||
| auto &cnodes = func_graph->func_graph_cnodes_index(); | |||
| for (auto &cnode : cnodes) { | |||
| auto parent = cnode.first->first->cast<CNodePtr>(); | |||
| auto valuenode = parent->input(cnode.first->second); | |||
| @@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| const AnfNodeSet &nodes = manager_->nodes()[func_graph]; | |||
| const AnfNodeSet &nodes = func_graph->nodes(); | |||
| for (auto &node : nodes) { | |||
| CloneNode(node, target_func_graph); | |||
| } | |||
| @@ -78,19 +78,6 @@ void FuncGraphManager::Reset() { | |||
| node_users_ = NodeUsersMap(); | |||
| signals_ = std::make_shared<Signals>(); | |||
| // FuncGraph --> AnfNode | |||
| nodes_ = std::make_shared<NodesCollector>(this); | |||
| // FuncGraph --> {AnfNode, Count} | |||
| valuenodes_ = std::make_shared<ValueNodesCollector>(this); | |||
| free_variables_direct_ = std::make_shared<FVDirectCollector>(this); | |||
| func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this); | |||
| // FuncGraph --> {FuncGraph, Count} | |||
| func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); | |||
| func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); | |||
| func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); | |||
| func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); | |||
| func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this); | |||
| func_graph_parent_ = std::make_shared<ParentComputer>(this); | |||
| @@ -210,7 +197,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { | |||
| } | |||
| AddIntoManaged(func_graph); | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| signals_->AddFuncGraph(func_graph); | |||
| std::vector<AnfNodePtr> para = func_graph->parameters(); | |||
| AcquireNodes(para); | |||
| std::vector<AnfNodePtr> return_vec({func_graph->get_return()}); | |||
| @@ -224,7 +210,6 @@ void FuncGraphManager::Clear() { | |||
| node_users_.clear(); | |||
| roots_.clear(); | |||
| signals_->InvalidateCollector(); | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| @@ -303,8 +288,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||
| MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); | |||
| auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; | |||
| auto &users_cnode_index = func_graph->func_graph_cnodes_index(); | |||
| if (!users_cnode_index.empty() && !ignore_users) { | |||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | |||
| continue; | |||
| @@ -320,7 +304,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| for (auto &fg : dropped) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| signals_->DropFuncGraph(fg); | |||
| all_nodes_.difference_update(fg->parameters()); | |||
| (void)func_graphs_.erase(fg); | |||
| if (fg->manager().get() == this) { | |||
| @@ -339,7 +322,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||
| return; | |||
| } | |||
| (void)users_node.erase(make_pair(node, index)); | |||
| signals_->DropEdge(node, index, inp); | |||
| DropEdge(node, index, inp); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | |||
| if (inp->func_graph() != nullptr) { | |||
| @@ -352,7 +335,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||
| auto &users_node = node_users_[inp]; | |||
| users_node.add(make_pair(node, index)); | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| signals_->AddEdge(node, index, inp); | |||
| AddEdge(node, index, inp); | |||
| } | |||
| } | |||
| @@ -392,8 +375,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) { | |||
| FuncGraphPtr fg = node->func_graph(); | |||
| if (fg != nullptr) { | |||
| AddFuncGraph(fg); | |||
| fg->AddNode(node); | |||
| } | |||
| signals_->AddNode(node); | |||
| ProcessInputs(node, kIncEdge); | |||
| } | |||
| } | |||
| @@ -424,7 +407,10 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> & | |||
| } | |||
| ProcessInputs(node, kDecEdge); | |||
| (void)all_nodes_.erase(node); | |||
| signals_->DropNode(node); | |||
| if (node->func_graph() != nullptr) { | |||
| node->func_graph()->DropNode(node); | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| nodes_ordered.update(cnode->inputs()); | |||
| @@ -462,35 +448,21 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||
| int index = 0; | |||
| (void)node_users_[source_prim].erase(make_pair(source_return, index)); | |||
| signals_->DropEdge(source_return, index, source_prim); | |||
| DropEdge(source_return, index, source_prim); | |||
| index = 1; | |||
| (void)node_users_[source_output].erase(make_pair(source_return, index)); | |||
| signals_->DropEdge(source_return, index, source_output); | |||
| DropEdge(source_return, index, source_output); | |||
| (void)all_nodes_.erase(source_return); | |||
| (void)node_users_.erase(source_return); | |||
| signals_->DropNode(source_return); | |||
| source->DropNode(source_return); | |||
| for (auto &node : source->nodes()) { | |||
| node->set_func_graph(target); | |||
| if (node->scope() == kDefaultScope) { | |||
| node->set_scope(scope); | |||
| } | |||
| } | |||
| for (auto &child : this->func_graph_child_direct()[source]) { | |||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | |||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | |||
| } | |||
| for (auto &fv_count : this->free_variables_direct()[source]) { | |||
| auto fv_g = fv_count.first->func_graph(); | |||
| auto &count_on_g = this->func_graph_child_direct()[fv_g]; | |||
| auto pair = count_on_g.find(source); | |||
| if (fv_g != target && pair != count_on_g.end()) { | |||
| (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); | |||
| } | |||
| (void)count_on_g.erase(source); | |||
| } | |||
| signals_->MoveAllCNode(source, target); | |||
| signals_->InvalidateComputer(); | |||
| signals_->DropFuncGraph(source); | |||
| MoveAllNodes(source, target); | |||
| all_nodes_.difference_update(source->parameters()); | |||
| (void)func_graphs_.erase(source); | |||
| if (source->manager().get() == this) { | |||
| @@ -498,6 +470,64 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||
| } | |||
| } | |||
| inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| auto fg = node->func_graph(); | |||
| if (input->isa<ValueNode>()) { | |||
| fg->AddValueNode(input); | |||
| if (IsValueNode<FuncGraph>(input)) { | |||
| if (fg->AddFuncGraphValueNode(input)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| auto used = GetValueNode<FuncGraphPtr>(input); | |||
| used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->AddJFuncGraphValueNode(input); | |||
| } | |||
| } | |||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||
| if (fg->AddFreeVariable(input)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| } | |||
| } | |||
| inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| auto fg = node->func_graph(); | |||
| if (input->isa<ValueNode>()) { | |||
| fg->DropValueNode(input); | |||
| if (IsValueNode<FuncGraph>(input)) { | |||
| if (fg->DropFuncGraphValueNode(input)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| auto used = GetValueNode<FuncGraphPtr>(input); | |||
| used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->DropJFuncGraphValueNode(input); | |||
| } | |||
| } | |||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||
| if (fg->DropFreeVariable(input)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| } | |||
| } | |||
| inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||
| target->CopyNodes(source->nodes()); | |||
| target->CopyValueNodes(source->value_nodes()); | |||
| target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index()); | |||
| target->CopyFreeVariables(source->free_variables()); | |||
| target->CopyFuncGraphValueNodes(source->func_graph_value_nodes()); | |||
| target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes()); | |||
| signals_->InvalidateComputer(); | |||
| source->ClearNodes(); | |||
| source->ClearValueNodes(); | |||
| source->ClearFuncGraphCNodesIndex(); | |||
| source->ClearFreeVariables(); | |||
| source->ClearFuncGraphValueNodes(); | |||
| source->ClearJFuncGraphValueNodes(); | |||
| } | |||
| FuncGraphTransaction FuncGraphManager::Transact() { | |||
| auto tr = FuncGraphTransaction(this); | |||
| return tr; | |||
| @@ -630,7 +660,6 @@ void NodesCollector::OnAddNode(AnfNodePtr n) { | |||
| if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) { | |||
| nodes_analysis_[n->func_graph()] = AnfNodeSet(); | |||
| } | |||
| nodes_analysis_[n->func_graph()].add(n); | |||
| } | |||
| @@ -910,17 +939,19 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f | |||
| return std::make_shared<FuncGraphSet>(); | |||
| } | |||
| FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | |||
| FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; | |||
| for (auto &dep : deps[fg]) { | |||
| MS_EXCEPTION_IF_NULL(dep.first); | |||
| auto proxy = dep.first->transforms().find("proxy"); | |||
| if (proxy != dep.first->transforms().end()) { | |||
| path->add(fg); | |||
| auto gt = proxy->second.func_graph(); | |||
| parents->update(SeekParents(gt, path)); | |||
| } else { | |||
| parents->add(dep.first); | |||
| } | |||
| // Append all the fvs in fg. | |||
| auto &fvs = fg->free_variables(); | |||
| for (auto fv : fvs) { | |||
| parents->add(fv.first->func_graph()); | |||
| } | |||
| // Search the fv in fg's child func graph. | |||
| auto &fg_value_nodes = fg->func_graph_value_nodes(); | |||
| for (auto &fg_value_node : fg_value_nodes) { | |||
| path->add(fg); | |||
| auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first); | |||
| parents->update(SeekParents(gt, path)); | |||
| } | |||
| (void)parents->erase(fg); | |||
| return parents; | |||
| @@ -928,10 +959,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f | |||
| void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| all_parents_direct_ = &(manager_->func_graph_parents_direct()); | |||
| MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size(); | |||
| func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); | |||
| MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); | |||
| } | |||
| bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { | |||
| @@ -1001,28 +1029,30 @@ void FVTotalComputer::RealRecompute() { | |||
| } | |||
| for (auto &fg : manager->func_graphs()) { | |||
| AnfNodeCounterMap items = manager->free_variables_direct()[fg]; | |||
| AnfNodeCounterMap items = fg->free_variables(); | |||
| for (auto &iter : items) { | |||
| auto curr = fg; | |||
| while (curr) { | |||
| while (curr != nullptr) { | |||
| (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | |||
| curr = manager->parent(curr); | |||
| const AnfNodeSet &nodes = manager->nodes()[curr]; | |||
| if (nodes.contains(iter.first)) { | |||
| break; | |||
| if (curr != nullptr) { | |||
| const AnfNodeSet &all_nodes = curr->nodes(); | |||
| if (all_nodes.contains(iter.first)) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| auto items_fg = manager->func_graphs_used()[fg]; | |||
| for (auto &iter : items_fg) { | |||
| auto p = manager->parent(iter.first); | |||
| auto &used = fg->func_graph_value_nodes(); | |||
| for (auto &iter : used) { | |||
| auto p = manager->parent(GetValueNode<FuncGraphPtr>(iter.first)); | |||
| if (p == nullptr) { | |||
| continue; | |||
| } | |||
| auto curr = fg; | |||
| while (curr != p) { | |||
| (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); | |||
| (void)CounterFuncGraphCollector::Mod(curr, GetValueNode<FuncGraphPtr>(iter.first), iter.second); | |||
| curr = manager->parent(curr); | |||
| } | |||
| } | |||
| @@ -1041,7 +1071,6 @@ void FVTotalComputer::RealRecompute() { | |||
| void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto &used = this->manager_->func_graphs_used(); | |||
| std::vector<FuncGraphPtr> todo; | |||
| std::vector<FuncGraphPtr> todo_new; | |||
| @@ -1049,8 +1078,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| while (!todo.empty()) { | |||
| todo_new.clear(); | |||
| for (auto > : todo) { | |||
| for (auto &item : used[gt]) { | |||
| auto used_fg = item.first; | |||
| for (auto &item : gt->func_graph_value_nodes()) { | |||
| auto used_fg = GetValueNode<FuncGraphPtr>(item.first); | |||
| if (used_fg == fg) { | |||
| func_graph_used_total_analysis_[fg].add(used_fg); | |||
| continue; | |||
| @@ -1068,7 +1097,6 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &used = manager->func_graphs_used(); | |||
| std::vector<FuncGraphPtr> todo; | |||
| std::vector<FuncGraphPtr> todo_new; | |||
| todo.push_back(fg); | |||
| @@ -1076,8 +1104,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f | |||
| while (!todo.empty()) { | |||
| todo_new.clear(); | |||
| for (auto > : todo) { | |||
| for (auto &item : used[gt]) { | |||
| auto used_g = item.first; | |||
| for (auto &item : gt->func_graph_value_nodes()) { | |||
| auto used_g = GetValueNode<FuncGraphPtr>(item.first); | |||
| if (used_g == fg) { | |||
| return true; | |||
| } | |||
| @@ -1108,9 +1136,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F | |||
| } | |||
| } else { | |||
| trace->push_back(fg); | |||
| auto &used_fgs = manager_->func_graphs_used()[fg]; | |||
| for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { | |||
| CheckRecursiveGraphs(iter->first, trace); | |||
| auto &items = fg->func_graph_value_nodes(); | |||
| for (auto iter = items.begin(); iter != items.end(); (void)iter++) { | |||
| CheckRecursiveGraphs(GetValueNode<FuncGraphPtr>(iter->first), trace); | |||
| } | |||
| trace->pop_back(); | |||
| if (!recursive_map_.count(fg)) { | |||
| @@ -1125,14 +1153,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt | |||
| MS_LOG(DEBUG) << fg->ToString() << " had been checked"; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto &func_graph_counter_map = manager_->func_graph_j_direct(); | |||
| if (!func_graph_counter_map[fg].empty()) { | |||
| auto &j_fg_value_nodes = fg->j_func_graph_value_nodes(); | |||
| if (!j_fg_value_nodes.empty()) { | |||
| // check g1->J(fg)->g2->g cycle; | |||
| auto contains_j = | |||
| std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(), | |||
| [path](const std::pair<FuncGraphPtr, int> iter) { return !path->contains(iter.first); }); | |||
| if (contains_j != func_graph_counter_map[fg].end()) { | |||
| std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), | |||
| [path](const std::pair<AnfNodePtr, int> iter) { return !path->contains(GetValueNode<FuncGraphPtr>(iter.first)); }); | |||
| if (contains_j != j_fg_value_nodes.end()) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; | |||
| return true; | |||
| } | |||
| @@ -1140,9 +1167,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt | |||
| path->add(fg); | |||
| // check if func graphs used contains J(func_graph); | |||
| auto &used = this->manager_->func_graphs_used(); | |||
| for (auto &item : used[fg]) { | |||
| auto used_g = item.first; | |||
| for (auto &item : fg->func_graph_value_nodes()) { | |||
| auto used_g = GetValueNode<FuncGraphPtr>(item.first); | |||
| if (SeekJ(used_g, path)) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; | |||
| return true; | |||
| @@ -367,8 +367,8 @@ class DepComputer : public FuncGraphAnalysis { | |||
| // graph g's all direct or proxy parents | |||
| class FuncGraphParentsTotalComputer final : public DepComputer { | |||
| public: | |||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||
| ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } | |||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~FuncGraphParentsTotalComputer() override = default; | |||
| FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | |||
| @@ -383,9 +383,6 @@ class FuncGraphParentsTotalComputer final : public DepComputer { | |||
| private: | |||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); | |||
| // when SeekParents calls itself recursively, it can access these variables by class member | |||
| // other than pass by formal parameters, it can save 1 parameter for SeekParents(). | |||
| FuncGraphToFuncGraphCounterMap *all_parents_direct_; | |||
| }; | |||
| using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | |||
| @@ -562,30 +559,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| NodeUsersMap &node_users() { return node_users_; } | |||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | |||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const { | |||
| return free_variables_direct_->count_nodes_map_; | |||
| } | |||
| FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const { | |||
| return func_graph_cnodes_index_->count_nodes_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | |||
| return func_graph_child_direct_->count_func_graphs_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { | |||
| return func_graph_parents_direct_->count_func_graphs_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||
| FVTotalMap &free_variables_total() const; | |||
| FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; | |||
| @@ -610,14 +583,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| // Static Analysis | |||
| NodeUsersMap node_users_; | |||
| AnfNodeSet all_nodes_; // managed nodes | |||
| std::shared_ptr<NodesCollector> nodes_; | |||
| std::shared_ptr<ValueNodesCollector> valuenodes_; | |||
| std::shared_ptr<FVDirectCollector> free_variables_direct_; | |||
| std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_; | |||
| std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; | |||
| std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; | |||
| std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; | |||
| std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; | |||
| // Dynamic Analysis | |||
| std::shared_ptr<ParentComputer> func_graph_parent_; | |||
| @@ -630,6 +595,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); | |||
| void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, | |||
| Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); | |||
| void AddEdge(AnfNodePtr node, int index, AnfNodePtr input); | |||
| void DropEdge(AnfNodePtr node, int index, AnfNodePtr input); | |||
| void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target); | |||
| FuncGraphSet roots_; // managed roots | |||
| FuncGraphSet func_graphs_; // managed func graphs | |||
| @@ -491,7 +491,7 @@ void DFunctor::MapParamObject() { | |||
| void DFunctor::MapValueObject() { | |||
| // Map ValueNode. | |||
| auto manager = resources_->manager(); | |||
| auto &value_nodes = manager->valuenodes()[primal_graph_]; | |||
| auto &value_nodes = primal_graph_->value_nodes(); | |||
| for (const auto &value_pair : value_nodes) { | |||
| auto node = value_pair.first; | |||
| auto parent_adjoint = FindAdjoint(node); | |||
| @@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes( | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; | |||
| // record the node input to be replaced | |||
| NodeInputReplMap repl_node_inputs; | |||
| const AnfNodeSet &nodes = manager->nodes()[graph]; | |||
| const AnfNodeSet &nodes = graph->nodes(); | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| @@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode( | |||
| ResetSharedOp(); | |||
| std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = | |||
| std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced | |||
| const AnfNodeSet &nodes = manager->nodes()[graph]; | |||
| const AnfNodeSet &nodes = graph->nodes(); | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| @@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) { | |||
| SetForwardFlag(all_nodes); | |||
| } else { | |||
| for (auto &func_graph : graph_set) { | |||
| MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); | |||
| MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size(); | |||
| auto return_node = func_graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); | |||
| @@ -389,7 +389,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| auto manager = res->manager(); | |||
| // Remove duplicated value nodes, due to replace operation, can't use reference. | |||
| auto value_nodes = manager->valuenodes()[func_graph]; | |||
| auto value_nodes = func_graph->value_nodes(); | |||
| HashCache hash_cache; | |||
| HashValue hashes; | |||
| for (const auto &value_pair : value_nodes) { | |||
| @@ -487,12 +487,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { | |||
| void TraverseGraphMap( | |||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | |||
| const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, | |||
| const FuncGraphSet &fgs, | |||
| const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| MS_EXCEPTION_IF_NULL(tr); | |||
| for (const auto &ct_graphs : cts) { | |||
| for (const auto &ct_any : ct_graphs.second) { | |||
| for (const auto &fg : fgs) { | |||
| for (const auto &ct_any : fg->value_nodes()) { | |||
| AnfNodePtr const_primitive_node = ct_any.first; | |||
| if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { | |||
| auto users = manager_ptr->node_users()[const_primitive_node]; | |||
| @@ -552,8 +552,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { | |||
| }; | |||
| FuncGraphTransaction tr = manager_ptr->Transact(); | |||
| auto &cts = manager_ptr->valuenodes(); | |||
| TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); | |||
| auto &fgs = manager_ptr->func_graphs(); | |||
| TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); | |||
| return graph; | |||
| } | |||