From 3ae925115ff64e88d95e65b1a520abf316d719e4 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 22 May 2020 11:33:28 +0800 Subject: [PATCH 1/4] Refactoring FuncGraphManager module: Move all info. of nodes and edges from FuncGraphManager into FuncGraph. --- mindspore/ccsrc/ir/base.h | 1 + mindspore/ccsrc/ir/func_graph.cc | 207 +++++++++++++++--- mindspore/ccsrc/ir/func_graph.h | 67 +++++- mindspore/ccsrc/ir/func_graph_cloner.cc | 12 +- mindspore/ccsrc/ir/manager.cc | 192 +++++++++------- mindspore/ccsrc/ir/manager.h | 42 +--- mindspore/ccsrc/optimizer/ad/dfunctor.cc | 2 +- .../ccsrc/optimizer/irpass/branch_culling.cc | 4 +- mindspore/ccsrc/parallel/step_parallel.cc | 2 +- mindspore/ccsrc/pipeline/action.cc | 2 +- mindspore/ccsrc/vm/transform.cc | 10 +- 11 files changed, 364 insertions(+), 177 deletions(-) diff --git a/mindspore/ccsrc/ir/base.h b/mindspore/ccsrc/ir/base.h index 7ccef13876..7dc4145837 100644 --- a/mindspore/ccsrc/ir/base.h +++ b/mindspore/ccsrc/ir/base.h @@ -29,6 +29,7 @@ #include "utils/visible.h" #include "utils/log_adapter.h" #include "utils/ordered_set.h" +#include "utils/ordered_map.h" namespace mindspore { template diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 40417a33da..4dc84cfa01 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -195,25 +195,88 @@ GraphDebugInfoPtr FuncGraph::debug_info() { return this->debug_info_; } -const AnfNodeSet &FuncGraph::nodes() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &nodes = mng->nodes(); - return nodes[shared_from_base()]; +const AnfNodeSet &FuncGraph::nodes() { return nodes_; } + +void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; } + +void FuncGraph::ClearNodes() { nodes_.clear(); } + +void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } + +void FuncGraph::DropNode(AnfNodePtr node) { + nodes_.erase(node); + auto graph = node->func_graph(); + // Remove the node from order list. + if (graph) { + graph->EraseUnusedNodeInOrder(node); + } } -const AnfNodeCounterMap &FuncGraph::value_nodes() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &cts = mng->valuenodes(); - return cts[shared_from_base()]; +const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } + +void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; } + +void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } + +void FuncGraph::AddValueNode(AnfNodePtr node, int count) { + if (value_nodes_.count(node) == 0) { + value_nodes_[node] = count; + } else { + value_nodes_[node] += count; + } } -const AnfNodeCounterMap &FuncGraph::free_variables_direct() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &fv_direct = mng->free_variables_direct(); - return fv_direct[shared_from_base()]; +void FuncGraph::DropValueNode(AnfNodePtr node) { + if (value_nodes_.count(node) != 0) { + if (value_nodes_[node] == 1) { + (void)value_nodes_.erase(node); + } else { + value_nodes_[node]--; + if (value_nodes_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of ValueNode '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } + +void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) { + auto it = others.begin(); + for (; it != others.end(); it++) { + if (it->first->func_graph().get() != this) { + (void)AddFreeVariable(it->first, it->second); + } + } +} + +void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } + +bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { + if (free_variables_.count(node) == 0) { + free_variables_[node] = count; + return true; + } else { + free_variables_[node] += count; + return false; + } +} + +bool FuncGraph::DropFreeVariable(AnfNodePtr node) { + if (free_variables_.count(node) != 0) { + if (free_variables_[node] == 1) { + (void)free_variables_.erase(node); + return true; + } else { + free_variables_[node]--; + if (free_variables_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of free variable '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; } const BaseRefCounterMap &FuncGraph::free_variables_total() { @@ -249,11 +312,36 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const FuncGraphCounterMap &FuncGraph::func_graphs_used() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &used = mng->func_graphs_used(); - return used[shared_from_base()]; +const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; } + +void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; } + +void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); } + +bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) { + if (func_graph_value_nodes_.count(node) == 0) { + func_graph_value_nodes_[node] = count; + return true; + } else { + func_graph_value_nodes_[node] += count; + return false; + } +} + +bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) { + if (func_graph_value_nodes_.count(node) != 0) { + if (func_graph_value_nodes_[node] == 1) { + (void)func_graph_value_nodes_.erase(node); + return true; + } else { + func_graph_value_nodes_[node]--; + if (func_graph_value_nodes_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; } const FuncGraphSet &FuncGraph::func_graphs_used_total() { @@ -263,15 +351,68 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { return used; } -const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { - auto mng = manager_.lock(); - if (mng == nullptr) { - MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() - << " NodeInfo: " << trace::GetDebugInfo(debug_info()); +const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } + +void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) { + auto it = others.begin(); + for (; it != others.end(); it++) { + // Ignore the user graph who may own itself. + if (it->first->first->func_graph().get() != this) { + AddFuncGraphCNodeIndex(it->first, it->second); + } + } +} + +void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } + +void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { + if (func_graph_cnodes_index_.count(pair) == 0) { + func_graph_cnodes_index_[pair] = count; + } else { + func_graph_cnodes_index_[pair] += count; + } +} + +void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { + if (func_graph_cnodes_index_.count(pair) != 0) { + if (func_graph_cnodes_index_[pair] == 1) { + (void)func_graph_cnodes_index_.erase(pair); + } else { + func_graph_cnodes_index_[pair]--; + if (func_graph_cnodes_index_[pair] < 0) { + MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; } + +void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; } + +void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); } + +void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) { + if (j_func_graph_value_nodes_.count(node) == 0) { + j_func_graph_value_nodes_[node] = count; + } else { + j_func_graph_value_nodes_[node] += count; + } +} + +void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) { + if (j_func_graph_value_nodes_.count(node) != 0) { + if (j_func_graph_value_nodes_[node] == 1) { + (void)j_func_graph_value_nodes_.erase(node); + } else { + j_func_graph_value_nodes_[node]--; + if (j_func_graph_value_nodes_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } } - MS_EXCEPTION_IF_NULL(mng); - auto &cnode = mng->func_graph_cnodes_index(); - return cnode[shared_from_base()]; } FuncGraphPtr FuncGraph::parent() { @@ -662,10 +803,10 @@ void FuncGraph::EraseUnusedNodeInOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { auto mng = manager_.lock(); if (mng) { - auto nodes = mng->nodes()[shared_from_base()]; + auto &all_nodes = nodes(); // Erase unused cnode. for (auto it = order_.begin(); it != order_.end();) { - if (nodes.count(*it)) { + if (all_nodes.count(*it)) { (void)it++; } else { MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; @@ -702,11 +843,11 @@ void FuncGraph::CheckOrder() { } auto mng = manager_.lock(); if (mng != nullptr) { - const auto &nodes = mng->nodes()[shared_from_base()]; - if (nodes.size() != (order_.size() + parameters_.size())) { + const auto &all_nodes = nodes(); + if (all_nodes.size() != (order_.size() + parameters_.size())) { DumpCNodeList(); MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " - << nodes.size() - parameters_.size() << "."; + << all_nodes.size() - parameters_.size() << "."; } } MS_LOG(DEBUG) << "Check order okay."; diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 91fea89eb3..02a18a9809 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "ir/anf.h" #include "ir/manager.h" @@ -36,8 +37,13 @@ namespace mindspore { using BaseRefCounterMap = OrderedMap; using FuncGraphCounterMap = OrderedMap; -using AnfNodeCounterMap = OrderedMap; -using CNodeIndexCounterMap = OrderedMap; + +template , class CounterEqual = std::equal_to> +using CounterOrderedMap = OrderedMap; +using AnfNodeCounterMap = CounterOrderedMap; +using CNodeIndexCounterMap = CounterOrderedMap; + +using FuncGraphMap = OrderedMap; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; @@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase { // get all nodes belonging to this func graph const AnfNodeSet &nodes(); + void CopyNodes(const AnfNodeSet &other_nodes); + void ClearNodes(); + void AddNode(AnfNodePtr node); + void DropNode(AnfNodePtr node); // get all value_nodes belonging to this func graph const AnfNodeCounterMap &value_nodes(); - - // get all vars directly pointed to in this func graph - const AnfNodeCounterMap &free_variables_direct(); + void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes); + void ClearValueNodes(); + void AddValueNode(AnfNodePtr node, int count = 1); + void DropValueNode(AnfNodePtr node); + + // get all free vars directly used in this func graph + const AnfNodeCounterMap &free_variables(); + void CopyFreeVariables(const AnfNodeCounterMap &others); + void ClearFreeVariables(); + bool AddFreeVariable(AnfNodePtr node, int count = 1); + bool DropFreeVariable(AnfNodePtr node); // get all vars required by this func graph const BaseRefCounterMap &free_variables_total(); @@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase { // get all vars that are func graphs std::vector free_variables_func_graphs(); - // get all func graphs directly used by this func graph - const FuncGraphCounterMap &func_graphs_used(); + // get all value nodes of func graph directly used by this func graph + const AnfNodeCounterMap &func_graph_value_nodes(); + void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others); + void ClearFuncGraphValueNodes(); + bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1); + bool DropFuncGraphValueNode(AnfNodePtr node); + + // get all value nodes of J func graph directly used by this func graph + const AnfNodeCounterMap &j_func_graph_value_nodes(); + void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others); + void ClearJFuncGraphValueNodes(); + void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1); + void DropJFuncGraphValueNode(AnfNodePtr node); // get all func graphs nested used by this func graph const FuncGraphSet &func_graphs_used_total(); - // get all user value nodes of this func graph + // get all user value nodes of this func graph, by CNode and its input's index const CNodeIndexCounterMap &func_graph_cnodes_index(); + void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes); + void ClearFuncGraphCNodesIndex(); + void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); + void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); // Return the parent of this graph. FuncGraphPtr parent(); @@ -270,6 +303,24 @@ class FuncGraph : public FuncGraphBase { // graph is manipulated by manager and others friend FuncGraphManager; + // all nodes of the function + AnfNodeSet nodes_; + + // all value nodes of the function + AnfNodeCounterMap value_nodes_; + + // all func graph value nodes of the function + AnfNodeCounterMap func_graph_value_nodes_; + + // all free variables of the function + AnfNodeCounterMap free_variables_; + + // all value nodes calling J in the function + AnfNodeCounterMap j_func_graph_value_nodes_; + + // all user value nodes of this func graph, recording by CNode and its input's index + CNodeIndexCounterMap func_graph_cnodes_index_; + // parameters of this function std::vector parameters_; std::vector paramter_obj_nodes_; diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index ab0a4fb19c..99d7c316e9 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { if (!clone_all_valuenodes_) { return; } - auto &value_nodes = manager_->valuenodes()[func_graph]; + auto &value_nodes = func_graph->value_nodes(); for (auto &value_node : value_nodes) { auto old_node = value_node.first; MS_EXCEPTION_IF_NULL(old_node); @@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { if (!clone_all_used_graphs_) { return; } - auto &used_graphs = manager_->func_graphs_used()[func_graph]; - for (auto &used_graph : used_graphs) { - todo_.push_back({used_graph.first, nullptr, {}}); + auto &used = func_graph->func_graph_value_nodes(); + for (auto &fg_value_node : used) { + todo_.push_back({GetValueNode(fg_value_node.first), nullptr, {}}); } } @@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func } target_func_graph->set_return(return_node); - auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; + auto &cnodes = func_graph->func_graph_cnodes_index(); for (auto &cnode : cnodes) { auto parent = cnode.first->first->cast(); auto valuenode = parent->input(cnode.first->second); @@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet &nodes = manager_->nodes()[func_graph]; + const AnfNodeSet &nodes = func_graph->nodes(); for (auto &node : nodes) { CloneNode(node, target_func_graph); } diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 1ed747eefd..ffa42c9177 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -78,19 +78,6 @@ void FuncGraphManager::Reset() { node_users_ = NodeUsersMap(); signals_ = std::make_shared(); - // FuncGraph --> AnfNode - nodes_ = std::make_shared(this); - - // FuncGraph --> {AnfNode, Count} - valuenodes_ = std::make_shared(this); - free_variables_direct_ = std::make_shared(this); - func_graph_cnodes_index_ = std::make_shared(this); - - // FuncGraph --> {FuncGraph, Count} - func_graphs_used_ = std::make_shared(this); - func_graph_child_direct_ = std::make_shared(this); - func_graph_parents_direct_ = std::make_shared(this); - func_graph_j_direct_ = std::make_shared(this); func_graph_parents_total_ = std::make_shared(this); func_graph_parent_ = std::make_shared(this); @@ -210,7 +197,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { } AddIntoManaged(func_graph); MS_EXCEPTION_IF_NULL(signals_); - signals_->AddFuncGraph(func_graph); std::vector para = func_graph->parameters(); AcquireNodes(para); std::vector return_vec({func_graph->get_return()}); @@ -224,7 +210,6 @@ void FuncGraphManager::Clear() { node_users_.clear(); roots_.clear(); - signals_->InvalidateCollector(); signals_->InvalidateComputer(); } @@ -303,8 +288,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); continue; } - MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); - auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; + auto &users_cnode_index = func_graph->func_graph_cnodes_index(); if (!users_cnode_index.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; @@ -320,7 +304,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool MS_EXCEPTION_IF_NULL(signals_); for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); - signals_->DropFuncGraph(fg); all_nodes_.difference_update(fg->parameters()); (void)func_graphs_.erase(fg); if (fg->manager().get() == this) { @@ -339,7 +322,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E return; } (void)users_node.erase(make_pair(node, index)); - signals_->DropEdge(node, index, inp); + DropEdge(node, index, inp); } else { MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); if (inp->func_graph() != nullptr) { @@ -352,7 +335,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E auto &users_node = node_users_[inp]; users_node.add(make_pair(node, index)); MS_EXCEPTION_IF_NULL(signals_); - signals_->AddEdge(node, index, inp); + AddEdge(node, index, inp); } } @@ -392,8 +375,8 @@ void FuncGraphManager::AcquireNodes(const std::vector &nodes) { FuncGraphPtr fg = node->func_graph(); if (fg != nullptr) { AddFuncGraph(fg); + fg->AddNode(node); } - signals_->AddNode(node); ProcessInputs(node, kIncEdge); } } @@ -424,7 +407,10 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector & } ProcessInputs(node, kDecEdge); (void)all_nodes_.erase(node); - signals_->DropNode(node); + if (node->func_graph() != nullptr) { + node->func_graph()->DropNode(node); + } + if (node->isa()) { auto cnode = node->cast(); nodes_ordered.update(cnode->inputs()); @@ -462,35 +448,21 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t int index = 0; (void)node_users_[source_prim].erase(make_pair(source_return, index)); - signals_->DropEdge(source_return, index, source_prim); + DropEdge(source_return, index, source_prim); index = 1; (void)node_users_[source_output].erase(make_pair(source_return, index)); - signals_->DropEdge(source_return, index, source_output); + DropEdge(source_return, index, source_output); (void)all_nodes_.erase(source_return); (void)node_users_.erase(source_return); - signals_->DropNode(source_return); + source->DropNode(source_return); for (auto &node : source->nodes()) { node->set_func_graph(target); if (node->scope() == kDefaultScope) { node->set_scope(scope); } } - for (auto &child : this->func_graph_child_direct()[source]) { - (void)func_graph_parents_direct_->Inc(child.first, target, child.second); - (void)this->func_graph_parents_direct()[child.first].erase(source); - } - for (auto &fv_count : this->free_variables_direct()[source]) { - auto fv_g = fv_count.first->func_graph(); - auto &count_on_g = this->func_graph_child_direct()[fv_g]; - auto pair = count_on_g.find(source); - if (fv_g != target && pair != count_on_g.end()) { - (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); - } - (void)count_on_g.erase(source); - } - signals_->MoveAllCNode(source, target); - signals_->InvalidateComputer(); - signals_->DropFuncGraph(source); + + MoveAllNodes(source, target); all_nodes_.difference_update(source->parameters()); (void)func_graphs_.erase(source); if (source->manager().get() == this) { @@ -498,6 +470,64 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t } } +inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->AddValueNode(input); + if (IsValueNode(input)) { + if (fg->AddFuncGraphValueNode(input)) { + signals_->InvalidateComputer(); + } + auto used = GetValueNode(input); + used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->AddJFuncGraphValueNode(input); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->AddFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->DropValueNode(input); + if (IsValueNode(input)) { + if (fg->DropFuncGraphValueNode(input)) { + signals_->InvalidateComputer(); + } + auto used = GetValueNode(input); + used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->DropJFuncGraphValueNode(input); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->DropFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { + target->CopyNodes(source->nodes()); + target->CopyValueNodes(source->value_nodes()); + target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index()); + target->CopyFreeVariables(source->free_variables()); + target->CopyFuncGraphValueNodes(source->func_graph_value_nodes()); + target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes()); + signals_->InvalidateComputer(); + source->ClearNodes(); + source->ClearValueNodes(); + source->ClearFuncGraphCNodesIndex(); + source->ClearFreeVariables(); + source->ClearFuncGraphValueNodes(); + source->ClearJFuncGraphValueNodes(); +} + FuncGraphTransaction FuncGraphManager::Transact() { auto tr = FuncGraphTransaction(this); return tr; @@ -630,7 +660,6 @@ void NodesCollector::OnAddNode(AnfNodePtr n) { if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) { nodes_analysis_[n->func_graph()] = AnfNodeSet(); } - nodes_analysis_[n->func_graph()].add(n); } @@ -910,17 +939,19 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f return std::make_shared(); } FuncGraphSetPtr parents = std::make_shared(); - FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; - for (auto &dep : deps[fg]) { - MS_EXCEPTION_IF_NULL(dep.first); - auto proxy = dep.first->transforms().find("proxy"); - if (proxy != dep.first->transforms().end()) { - path->add(fg); - auto gt = proxy->second.func_graph(); - parents->update(SeekParents(gt, path)); - } else { - parents->add(dep.first); - } + + // Append all the fvs in fg. + auto &fvs = fg->free_variables(); + for (auto fv : fvs) { + parents->add(fv.first->func_graph()); + } + + // Search the fv in fg's child func graph. + auto &fg_value_nodes = fg->func_graph_value_nodes(); + for (auto &fg_value_node : fg_value_nodes) { + path->add(fg); + auto gt = GetValueNode(fg_value_node.first); + parents->update(SeekParents(gt, path)); } (void)parents->erase(fg); return parents; @@ -928,10 +959,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(fg); - all_parents_direct_ = &(manager_->func_graph_parents_direct()); - MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size(); func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); - MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); } bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { @@ -1001,28 +1029,30 @@ void FVTotalComputer::RealRecompute() { } for (auto &fg : manager->func_graphs()) { - AnfNodeCounterMap items = manager->free_variables_direct()[fg]; + AnfNodeCounterMap items = fg->free_variables(); for (auto &iter : items) { auto curr = fg; - while (curr) { + while (curr != nullptr) { (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); - const AnfNodeSet &nodes = manager->nodes()[curr]; - if (nodes.contains(iter.first)) { - break; + if (curr != nullptr) { + const AnfNodeSet &all_nodes = curr->nodes(); + if (all_nodes.contains(iter.first)) { + break; + } } } } - auto items_fg = manager->func_graphs_used()[fg]; - for (auto &iter : items_fg) { - auto p = manager->parent(iter.first); + auto &used = fg->func_graph_value_nodes(); + for (auto &iter : used) { + auto p = manager->parent(GetValueNode(iter.first)); if (p == nullptr) { continue; } auto curr = fg; while (curr != p) { - (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); + (void)CounterFuncGraphCollector::Mod(curr, GetValueNode(iter.first), iter.second); curr = manager->parent(curr); } } @@ -1041,7 +1071,6 @@ void FVTotalComputer::RealRecompute() { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto &used = this->manager_->func_graphs_used(); std::vector todo; std::vector todo_new; @@ -1049,8 +1078,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : used[gt]) { - auto used_fg = item.first; + for (auto &item : gt->func_graph_value_nodes()) { + auto used_fg = GetValueNode(item.first); if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); continue; @@ -1068,7 +1097,6 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(manager); - auto &used = manager->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); @@ -1076,8 +1104,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : used[gt]) { - auto used_g = item.first; + for (auto &item : gt->func_graph_value_nodes()) { + auto used_g = GetValueNode(item.first); if (used_g == fg) { return true; } @@ -1108,9 +1136,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::listpush_back(fg); - auto &used_fgs = manager_->func_graphs_used()[fg]; - for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { - CheckRecursiveGraphs(iter->first, trace); + auto &items = fg->func_graph_value_nodes(); + for (auto iter = items.begin(); iter != items.end(); (void)iter++) { + CheckRecursiveGraphs(GetValueNode(iter->first), trace); } trace->pop_back(); if (!recursive_map_.count(fg)) { @@ -1125,14 +1153,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt MS_LOG(DEBUG) << fg->ToString() << " had been checked"; return false; } - MS_EXCEPTION_IF_NULL(manager_); - auto &func_graph_counter_map = manager_->func_graph_j_direct(); - if (!func_graph_counter_map[fg].empty()) { + auto &j_fg_value_nodes = fg->j_func_graph_value_nodes(); + if (!j_fg_value_nodes.empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = - std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(), - [path](const std::pair iter) { return !path->contains(iter.first); }); - if (contains_j != func_graph_counter_map[fg].end()) { + std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), + [path](const std::pair iter) { return !path->contains(GetValueNode(iter.first)); }); + if (contains_j != j_fg_value_nodes.end()) { MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; return true; } @@ -1140,9 +1167,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt path->add(fg); // check if func graphs used contains J(func_graph); - auto &used = this->manager_->func_graphs_used(); - for (auto &item : used[fg]) { - auto used_g = item.first; + for (auto &item : fg->func_graph_value_nodes()) { + auto used_g = GetValueNode(item.first); if (SeekJ(used_g, path)) { MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; return true; diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index 7f36b53205..cc4336056e 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -367,8 +367,8 @@ class DepComputer : public FuncGraphAnalysis { // graph g's all direct or proxy parents class FuncGraphParentsTotalComputer final : public DepComputer { public: - explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} - ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } + explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} + ~FuncGraphParentsTotalComputer() override = default; FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } @@ -383,9 +383,6 @@ class FuncGraphParentsTotalComputer final : public DepComputer { private: FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); - // when SeekParents calls itself recursively, it can access these variables by class member - // other than pass by formal parameters, it can save 1 parameter for SeekParents(). - FuncGraphToFuncGraphCounterMap *all_parents_direct_; }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -562,30 +559,6 @@ class FuncGraphManager : public std::enable_shared_from_this { NodeUsersMap &node_users() { return node_users_; } - FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - - FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - - FuncGraphToAnfNodeCounterMap &free_variables_direct() const { - return free_variables_direct_->count_nodes_map_; - } - - FuncGraphToAnfNodeCounterMap &func_graph_cnodes_index() const { - return func_graph_cnodes_index_->count_nodes_map_; - } - - FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } - - FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { - return func_graph_child_direct_->count_func_graphs_map_; - } - - FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { - return func_graph_parents_direct_->count_func_graphs_map_; - } - - FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } - FVTotalMap &free_variables_total() const; FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; @@ -610,14 +583,6 @@ class FuncGraphManager : public std::enable_shared_from_this { // Static Analysis NodeUsersMap node_users_; AnfNodeSet all_nodes_; // managed nodes - std::shared_ptr nodes_; - std::shared_ptr valuenodes_; - std::shared_ptr free_variables_direct_; - std::shared_ptr func_graph_cnodes_index_; - std::shared_ptr func_graphs_used_; - std::shared_ptr func_graph_child_direct_; - std::shared_ptr func_graph_parents_direct_; - std::shared_ptr func_graph_j_direct_; // Dynamic Analysis std::shared_ptr func_graph_parent_; @@ -630,6 +595,9 @@ class FuncGraphManager : public std::enable_shared_from_this { FuncGraphSetPtr MaybeDropNodes(const std::vector &nodes); void ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms); + void AddEdge(AnfNodePtr node, int index, AnfNodePtr input); + void DropEdge(AnfNodePtr node, int index, AnfNodePtr input); + void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target); FuncGraphSet roots_; // managed roots FuncGraphSet func_graphs_; // managed func graphs diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc index 6f648b5728..9b09c99948 100644 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc @@ -491,7 +491,7 @@ void DFunctor::MapParamObject() { void DFunctor::MapValueObject() { // Map ValueNode. auto manager = resources_->manager(); - auto &value_nodes = manager->valuenodes()[primal_graph_]; + auto &value_nodes = primal_graph_->value_nodes(); for (const auto &value_pair : value_nodes) { auto node = value_pair.first; auto parent_adjoint = FindAdjoint(node); diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc index d90b2bd44c..7f92ad9a1e 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc @@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes( std::unordered_map repl_node; // record the node input to be replaced NodeInputReplMap repl_node_inputs; - const AnfNodeSet &nodes = manager->nodes()[graph]; + const AnfNodeSet &nodes = graph->nodes(); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode( ResetSharedOp(); std::shared_ptr> repl_node = std::make_shared>(); // record the node to be replaced - const AnfNodeSet &nodes = manager->nodes()[graph]; + const AnfNodeSet &nodes = graph->nodes(); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 6c3b51347f..3b679a473f 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) { SetForwardFlag(all_nodes); } else { for (auto &func_graph : graph_set) { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size(); auto return_node = func_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 75f9f76db4..c4bce433b1 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -389,7 +389,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); auto manager = res->manager(); // Remove duplicated value nodes, due to replace operation, can't use reference. - auto value_nodes = manager->valuenodes()[func_graph]; + auto value_nodes = func_graph->value_nodes(); HashCache hash_cache; HashValue hashes; for (const auto &value_pair : value_nodes) { diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index e8b47a4bcd..d2ad52c4c7 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -487,12 +487,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { void TraverseGraphMap( const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, - const FuncGraphToAnfNodeCounterMap &cts, + const FuncGraphSet &fgs, const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr); - for (const auto &ct_graphs : cts) { - for (const auto &ct_any : ct_graphs.second) { + for (const auto &fg : fgs) { + for (const auto &ct_any : fg->value_nodes()) { AnfNodePtr const_primitive_node = ct_any.first; if (const_primitive_node != nullptr && IsValueNode(const_primitive_node)) { auto users = manager_ptr->node_users()[const_primitive_node]; @@ -552,8 +552,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { }; FuncGraphTransaction tr = manager_ptr->Transact(); - auto &cts = manager_ptr->valuenodes(); - TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); + auto &fgs = manager_ptr->func_graphs(); + TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); return graph; } From f31564ce98027834ab73a126b1b4ae4cae336362 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 22 May 2020 11:36:26 +0800 Subject: [PATCH 2/4] Remove the useless collectors in manager. --- mindspore/ccsrc/ir/manager.cc | 185 +-------------------------- mindspore/ccsrc/ir/manager.h | 151 ++++------------------ tests/ut/cpp/ir/manager_test.cc | 101 ++------------- tests/ut/cpp/optimizer/cconv_test.cc | 7 +- 4 files changed, 44 insertions(+), 400 deletions(-) diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index ffa42c9177..00b31d543e 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -640,53 +640,14 @@ void FuncGraphTransaction::Commit() { } FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) - : manager_(manager), include_func_graph_none_(false) { - manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); - manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); - manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge); - manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge); - manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); -} - -NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { - include_func_graph_none_ = true; - nodes_analysis_[nullptr] = AnfNodeSet(); - - manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode); - manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode); -} - -void NodesCollector::OnAddNode(AnfNodePtr n) { - if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) { - nodes_analysis_[n->func_graph()] = AnfNodeSet(); - } - nodes_analysis_[n->func_graph()].add(n); -} - -void NodesCollector::OnDropNode(AnfNodePtr n) { - (void)nodes_analysis_[n->func_graph()].erase(n); - auto graph = n->func_graph(); - // Remove the node from order list. - if (graph) { - graph->EraseUnusedNodeInOrder(n); - } -} - -void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // change the owner of node except for the src's return node - for (auto &it : nodes_analysis_[src]) { - nodes_analysis_[dst].add(it); - } - (void)nodes_analysis_.erase(src); -} - -void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } + : manager_(manager), include_func_graph_none_(false) {} DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); - manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); } +void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } + void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } template @@ -735,65 +696,6 @@ bool CounterAnfNodeCollector::Mod(const F } } -void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (inp->isa()) { - (void)Mod(node->func_graph(), inp, direction); - } -} - -void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - -void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, - EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), std::make_shared(std::make_pair(node, index)), - direction); - } -} - -void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - // Ignore the user graph who may own itself. - if (dst != it.first->first->func_graph()) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_nodes_map_.erase(src); -} - -void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inp); - FuncGraphPtr fg1 = node->func_graph(); - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - (void)Mod(fg1, inp, direction); - } -} - -void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - FuncGraphPtr fg2 = it.first->func_graph(); - if (fg2 != dst) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_nodes_map_.erase(src); -} - -static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { - FuncGraphPtr gn = std::make_shared(); - (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); - return gn; -} - bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { @@ -833,87 +735,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr } } -void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inp); - FuncGraphPtr fg1 = node->func_graph(); - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - (void)Mod(fg2, fg1, direction); - } -} - -void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_func_graphs_map_[src]) { - FuncGraphPtr fg = it.first; - if (fg != dst) { - (void)Inc(dst, fg, it.second); - } - } - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg1 = node->func_graph(); - // possible child parent - if (IsValueNode(inp)) { - FuncGraphPtr fg2 = GetValueNode(inp); - if (Mod(fg1, ParentProxy(fg2), direction)) { - manager_->signals()->InvalidateComputer(); - } - } - // from fv - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - // node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent - if (Mod(fg1, fg2, direction)) { - manager_->signals()->InvalidateComputer(); - } - } -} - -void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_func_graphs_map_[src]) { - if (it.first != dst) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(node->func_graph(), GetValueNode(inp), direction); - } -} - -void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // all graph use in src need to change to dst, so meger the to dst use - for (auto &it : count_func_graphs_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_func_graphs_map_[dst].erase(src); - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - if (IsValueNode(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { - (void)Mod(node->func_graph(), GetValueNode(inp), direction); - MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph " - << GetValueNode(inp)->ToString() << " which contains J(func_graph), dir: " << direction; - } -} - -void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // all graph use in src need to change to dst, so meger the to dst use - for (auto &it : count_func_graphs_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_func_graphs_map_.erase(src); -} - DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index cc4336056e..06b2859fea 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -140,44 +140,6 @@ class FuncGraphAnalysis { using FuncGraphToAnfNodeMap = OrderedMap; -// graphs analysis which compute in write, read needn't recompute -class DepCollector : public FuncGraphAnalysis { - public: - explicit DepCollector(const FuncGraphManager *manager); - ~DepCollector() override = default; - - void Reset() { ExtraReset(); } - void OnInvalidateCollector() { Reset(); } - - protected: - // inherit from FuncGraphAnalysis - void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - // subclass can override; - virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} -}; - -class NodesCollector final : public DepCollector { - public: - explicit NodesCollector(const FuncGraphManager *m); - ~NodesCollector() override = default; - - const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } - size_t size() const override { return nodes_analysis_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } - - void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); } - - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - FuncGraphToAnfNodeMap nodes_analysis_; - - protected: - void ExtraReset() override { nodes_analysis_.clear(); } - void OnAddNode(AnfNodePtr n) override; - void OnDropNode(AnfNodePtr n) override; -}; - struct CNodeIndexHasher { std::size_t operator()(const CNodeIndexPairPtr pair) const { MS_EXCEPTION_IF_NULL(pair); @@ -204,59 +166,21 @@ struct CNodeIndexEqual { } }; -template , class CollectorEqual = std::equal_to> -class CounterAnfNodeCollector : public DepCollector { - public: - explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } - - size_t size() const override { return count_nodes_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { - count_nodes_map_[fg] = OrderedMap(); - } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - - bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); - - FuncGraphToAnfNodeCounterMap count_nodes_map_; - - protected: - void ExtraReset() override { count_nodes_map_.clear(); } -}; - -class ValueNodesCollector final : public CounterAnfNodeCollector { - public: - explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~ValueNodesCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// Record the CNode and its input index, who points to the function graph. -class FuncGraphUsersCNodeIndexCollector final - : public CounterAnfNodeCollector { +// graphs analysis which compute in write, read needn't recompute +class DepCollector : public FuncGraphAnalysis { public: - explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FuncGraphUsersCNodeIndexCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; + explicit DepCollector(const FuncGraphManager *manager); + ~DepCollector() override = default; -class FVDirectCollector final : public CounterAnfNodeCollector { - public: - explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FVDirectCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; + void Reset() { ExtraReset(); } + void OnInvalidateCollector() { Reset(); } protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; + // inherit from FuncGraphAnalysis + void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; + void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; + // subclass can override; + virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} }; class CounterFuncGraphCollector : public DepCollector { @@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector { void ExtraReset() override { count_func_graphs_map_.clear(); } }; -class FuncGraphChildDirect final : public CounterFuncGraphCollector { - public: - explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - ~FuncGraphChildDirect() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy -// parents: -// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent -// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f -class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { +template , class CollectorEqual = std::equal_to> +class CounterAnfNodeCollector : public DepCollector { public: - explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - ~FuncGraphParentsDirectCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} + ~CounterAnfNodeCollector() override = default; + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } -// graph's all used graphs: key is g, value is g used graph -class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphsUsedCollector() override = default; + size_t size() const override { return count_nodes_map_.size(); } + void OnAddFuncGraph(FuncGraphPtr fg) final { + count_nodes_map_[fg] = OrderedMap(); + } + void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; + bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); -class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; - ~FuncGraphJDirectCollector() override = default; + FuncGraphToAnfNodeCounterMap count_nodes_map_; protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; + void ExtraReset() override { count_nodes_map_.clear(); } }; using FuncGraphToFuncGraphSetMap = OrderedMap; diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 8816277c49..25e66036a1 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -132,18 +132,6 @@ class NestingSpecs { CheckAnfNodeCounter(counter_p); return; } - - auto counter_pair = dynamic_pointer_cast>(results); - if (counter_pair != nullptr) { - CheckCNodeIndexPairCounter(counter_pair); - return; - } - - auto nodes = dynamic_pointer_cast(results); - if (nodes != nullptr) { - CheckNodes(nodes); - return; - } } private: @@ -205,33 +193,7 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } - void CheckNodes(std::shared_ptr results) { - std::map> clean_results; - for (auto& iter : results->nodes_analysis()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - if (!node->isa() && !Name(node).empty()) { - v.insert(Name(node)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - // Add CheckNesting function - void CheckAnfNodeCounter(std::shared_ptr> results) { std::map> clean_results; for (auto& iter : results->count_nodes_map()) { @@ -258,32 +220,6 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } - void CheckCNodeIndexPairCounter(std::shared_ptr> results) { - std::map> clean_results; - for (auto& iter : results->count_nodes_map()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - auto fg = node.first->first; - if (!Name(fg).empty()) { - v.insert(Name(fg)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - void CheckGraphCounter(std::shared_ptr results) { std::map> clean_results; for (auto& iter : results->count_func_graphs_map()) { @@ -471,17 +407,10 @@ std::vector MakeNestedGraph2() { } // Add TestManager::CheckManager function to checkout the result - void TestManager::CheckAnalysisSize(std::shared_ptr mng) { auto size = mng->func_graphs().size(); - ASSERT_EQ(size + 1, mng->nodes().size()); ASSERT_EQ(size, mng->free_variables_total().size()); - ASSERT_EQ(size, mng->valuenodes().size()); - ASSERT_EQ(size, mng->free_variables_direct().size()); - ASSERT_EQ(size, mng->func_graph_cnodes_index().size()); - ASSERT_EQ(size, mng->func_graph_parents_direct().size()); - ASSERT_EQ(size, mng->func_graphs_used().size()); } TEST_F(TestManager, test_scalar_add_manual) { @@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(1, mng->roots().size()); CheckAnalysisSize(mng); - auto nodes = mng->nodes(); - ASSERT_EQ(3, nodes[nullptr].size()); - ASSERT_EQ(2, nodes[f].size()); - ASSERT_EQ(1, nodes[g].size()); + ASSERT_EQ(2, f->nodes().size()); + ASSERT_EQ(1, g->nodes().size()); auto users = mng->node_users(); for (auto& iter : users) { ASSERT_EQ(1, iter.second.size()); } - auto graphs_used = mng->func_graphs_used(); - ASSERT_EQ(1, graphs_used[f].size()); - ASSERT_EQ(0, graphs_used[g].size()); + ASSERT_EQ(1, f->func_graph_value_nodes().size()); + ASSERT_EQ(0, g->func_graph_value_nodes().size()); - auto fv_direct = mng->free_variables_direct(); - ASSERT_EQ(0, fv_direct[f].size()); - ASSERT_EQ(1, fv_direct[g].size()); + ASSERT_EQ(0, f->free_variables().size()); + ASSERT_EQ(1, g->free_variables().size()); auto fv_total = mng->free_variables_total(); ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(1, fv_total[g].size()); - auto cnodes = mng->func_graph_cnodes_index(); - ASSERT_EQ(0, cnodes[f].size()); - ASSERT_EQ(1, cnodes[g].size()); + ASSERT_EQ(0, f->func_graph_cnodes_index().size()); + ASSERT_EQ(1, g->func_graph_cnodes_index().size()); } TEST_F(TestManager, test_deep_nested2_manual) { @@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) { ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(1, mng->roots().size()); - ASSERT_EQ(4, mng->nodes().size()); + ASSERT_EQ(4, gfn->nodes().size()); ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(25, mng->node_users().size()); CheckAnalysisSize(mng); @@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) { ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(1, mng->roots().size()); - ASSERT_EQ(4, mng->nodes().size()); ASSERT_EQ(20, mng->all_nodes().size()); CheckAnalysisSize(mng); } @@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) { FuncGraphPtr fg = getPyFun("ir_get_fn"); auto mng = Manage(fg); - const FuncGraphToAnfNodeMap& nodes = mng->nodes(); - ASSERT_TRUE(nodes.find(fg) != nodes.end()); + const auto &fgs = mng->func_graphs(); + ASSERT_TRUE(fgs.contains(fg)); FuncGraphSet s; s.add(fg); mng->MaybeDropFuncGraphs(s); - ASSERT_TRUE(nodes.find(fg) != nodes.end()); + ASSERT_TRUE(fgs.contains(fg)); } TEST_F(TestManager, test_keep_roots) { diff --git a/tests/ut/cpp/optimizer/cconv_test.cc b/tests/ut/cpp/optimizer/cconv_test.cc index 0b47c78cd3..8bd6957e85 100644 --- a/tests/ut/cpp/optimizer/cconv_test.cc +++ b/tests/ut/cpp/optimizer/cconv_test.cc @@ -26,15 +26,14 @@ namespace mindspore { void CheckNoFreeVariables(FuncGraphPtr root) { auto mng = Manage(root); - for (auto &iter : mng->nodes()) { - auto g = iter.first; - auto nodes = iter.second; + for (auto &iter : mng->func_graphs()) { + auto g = iter; if (g == nullptr) { continue; } - ASSERT_TRUE(g->parent() == nullptr); + auto nodes = g->nodes(); for (auto &node : nodes) { ASSERT_EQ(node->func_graph(), g); auto cnode = node->cast(); From 737bfc9595323949b91ee046a154b96df3d56b65 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 22 May 2020 15:08:32 +0800 Subject: [PATCH 3/4] Use FuncGraph generation number replacing set. --- mindspore/ccsrc/ir/func_graph.cc | 6 ++++++ mindspore/ccsrc/ir/func_graph.h | 3 +++ mindspore/ccsrc/ir/manager.cc | 27 +++++++++++++-------------- mindspore/ccsrc/ir/manager.h | 4 ++-- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 4dc84cfa01..4833e3838b 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -47,6 +47,7 @@ FuncGraph::FuncGraph() : flags_(), transforms_(), parameter_default_value_(), + seen_(0), parameters_(), has_vararg_(false), has_kwarg_(false), @@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { } } +size_t NewFgSeenGeneration() { + static size_t fg_seen_generation = 0; + return ++fg_seen_generation; +} + const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); const char kFuncGraphFlagUndetermined[] = "Undeterminate"; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 02a18a9809..f4c9d7079f 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase { // parameter default value std::map parameter_default_value_; std::unordered_map make_ref_params_; + size_t seen_; std::list GetOrderedCnodes(); void EraseUnusedNodeInOrder(const AnfNodePtr &n); @@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphP return fg->NewCNode(inputs); } +size_t NewFgSeenGeneration(); + // Find the root cnodes of a segment of cnodes. std::shared_ptr> FindRoots(const std::vector &segment); // Find the leaf cnodes of a segment of cnodes. diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 00b31d543e..cfaa84a05b 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) { } } -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { - if (path == nullptr || path->contains(fg)) { +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { return std::make_shared(); } FuncGraphSetPtr parents = std::make_shared(); @@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f // Search the fv in fg's child func graph. auto &fg_value_nodes = fg->func_graph_value_nodes(); for (auto &fg_value_node : fg_value_nodes) { - path->add(fg); + fg->seen_ = seen_num; auto gt = GetValueNode(fg_value_node.first); - parents->update(SeekParents(gt, path)); + parents->update(SeekParents(gt, seen_num)); } (void)parents->erase(fg); return parents; @@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(fg); - func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); + func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); } bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { @@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::listcontains(fg)) { +bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { MS_LOG(DEBUG) << fg->ToString() << " had been checked"; return false; } @@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt if (!j_fg_value_nodes.empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = - std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), - [path](const std::pair iter) { return !path->contains(GetValueNode(iter.first)); }); + std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair iter) { + return GetValueNode(iter.first)->seen_ != seen_num; + }); if (contains_j != j_fg_value_nodes.end()) { MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; return true; } } - path->add(fg); + fg->seen_ = seen_num; // check if func graphs used contains J(func_graph); for (auto &item : fg->func_graph_value_nodes()) { auto used_g = GetValueNode(item.first); - if (SeekJ(used_g, path)) { + if (SeekJ(used_g, seen_num)) { MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; return true; } @@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt } void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { - std::shared_ptr path = std::make_shared(); - this->j_total_analysis_[fg] = SeekJ(fg, path); + this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); } } // namespace mindspore diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index 06b2859fea..d748a08593 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer { void RealRecompute(FuncGraphPtr fg) override; private: - FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); + FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num); }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer { void ExtraReset() override { j_total_analysis_.clear(); } void RealRecompute(FuncGraphPtr fg) override; - bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); + bool SeekJ(const FuncGraphPtr &fg, size_t seen_num); }; class FuncGraphManager : public std::enable_shared_from_this { From dbb86cb1befadbc60b64d80889f1c0b5cfc0cc0a Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Sun, 24 May 2020 18:16:37 +0800 Subject: [PATCH 4/4] Adjust some routines of FG and FGM, about the nodes info. IF. --- mindspore/ccsrc/ir/func_graph.cc | 92 ++++++++++++++--------- mindspore/ccsrc/ir/func_graph.h | 32 ++++---- mindspore/ccsrc/ir/func_graph_cloner.cc | 6 +- mindspore/ccsrc/ir/manager.cc | 78 +++++++++---------- mindspore/ccsrc/parallel/step_parallel.cc | 2 +- tests/ut/cpp/ir/manager_test.cc | 4 +- 6 files changed, 113 insertions(+), 101 deletions(-) diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 4833e3838b..c5d7639e2e 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() { const AnfNodeSet &FuncGraph::nodes() { return nodes_; } -void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; } +void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } void FuncGraph::ClearNodes() { nodes_.clear(); } @@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) { const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } -void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; } +void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { + auto &others = source->value_nodes(); + for (auto it = others.begin(); it != others.end(); it++) { + AddValueNode(it->first, it->second); + } +} void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } @@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) { const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } -void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) { - auto it = others.begin(); - for (; it != others.end(); it++) { +void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { + auto &others = source->free_variables(); + for (auto it = others.begin(); it != others.end(); it++) { if (it->first->func_graph().get() != this) { (void)AddFreeVariable(it->first, it->second); } @@ -313,31 +318,37 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; } +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } -void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; } +void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { + auto &others = source->func_graphs_used(); + for (auto it = others.begin(); it != others.end(); it++) { + (void)AddFuncGraphUsed(it->first, it->second); + } + func_graphs_used_.erase(source); +} -void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); } +void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } -bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) { - if (func_graph_value_nodes_.count(node) == 0) { - func_graph_value_nodes_[node] = count; +bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { + if (func_graphs_used_.count(fg) == 0) { + func_graphs_used_[fg] = count; return true; } else { - func_graph_value_nodes_[node] += count; + func_graphs_used_[fg] += count; return false; } } -bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) { - if (func_graph_value_nodes_.count(node) != 0) { - if (func_graph_value_nodes_[node] == 1) { - (void)func_graph_value_nodes_.erase(node); +bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { + if (func_graphs_used_.count(fg) != 0) { + if (func_graphs_used_[fg] == 1) { + (void)func_graphs_used_.erase(fg); return true; } else { - func_graph_value_nodes_[node]--; - if (func_graph_value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node + func_graphs_used_[fg]--; + if (func_graphs_used_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); } } @@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } -void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) { - auto it = others.begin(); - for (; it != others.end(); it++) { +void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { + auto &others = source->func_graph_cnodes_index(); + for (auto it = others.begin(); it != others.end(); it++) { // Ignore the user graph who may own itself. - if (it->first->first->func_graph().get() != this) { + auto fg = it->first->first->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + if (fg.get() != this) { AddFuncGraphCNodeIndex(it->first, it->second); } } @@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { } } -const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; } +const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } -void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; } +void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { + auto &others = source->j_func_graphs(); + for (auto it = others.begin(); it != others.end(); it++) { + AddJFuncGraph(it->first, it->second); + } +} -void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); } +void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } -void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) { - if (j_func_graph_value_nodes_.count(node) == 0) { - j_func_graph_value_nodes_[node] = count; +void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { + if (j_func_graphs_.count(fg) == 0) { + j_func_graphs_[fg] = count; } else { - j_func_graph_value_nodes_[node] += count; + j_func_graphs_[fg] += count; } } -void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) { - if (j_func_graph_value_nodes_.count(node) != 0) { - if (j_func_graph_value_nodes_[node] == 1) { - (void)j_func_graph_value_nodes_.erase(node); +void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { + if (j_func_graphs_.count(fg) != 0) { + if (j_func_graphs_[fg] == 1) { + (void)j_func_graphs_.erase(fg); } else { - j_func_graph_value_nodes_[node]--; - if (j_func_graph_value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node + j_func_graphs_[fg]--; + if (j_func_graphs_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); } } diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index f4c9d7079f..8406f3b1ff 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase { // get all nodes belonging to this func graph const AnfNodeSet &nodes(); - void CopyNodes(const AnfNodeSet &other_nodes); + void CopyNodes(const FuncGraphPtr &source); void ClearNodes(); void AddNode(AnfNodePtr node); void DropNode(AnfNodePtr node); // get all value_nodes belonging to this func graph const AnfNodeCounterMap &value_nodes(); - void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes); + void CopyValueNodes(const FuncGraphPtr &source); void ClearValueNodes(); void AddValueNode(AnfNodePtr node, int count = 1); void DropValueNode(AnfNodePtr node); // get all free vars directly used in this func graph const AnfNodeCounterMap &free_variables(); - void CopyFreeVariables(const AnfNodeCounterMap &others); + void CopyFreeVariables(const FuncGraphPtr &source); void ClearFreeVariables(); bool AddFreeVariable(AnfNodePtr node, int count = 1); bool DropFreeVariable(AnfNodePtr node); @@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase { std::vector free_variables_func_graphs(); // get all value nodes of func graph directly used by this func graph - const AnfNodeCounterMap &func_graph_value_nodes(); - void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others); - void ClearFuncGraphValueNodes(); - bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1); - bool DropFuncGraphValueNode(AnfNodePtr node); + const FuncGraphCounterMap &func_graphs_used(); + void CopyFuncGraphsUsed(const FuncGraphPtr &source); + void ClearFuncGraphsUsed(); + bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); + bool DropFuncGraphUsed(FuncGraphPtr fg); // get all value nodes of J func graph directly used by this func graph - const AnfNodeCounterMap &j_func_graph_value_nodes(); - void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others); - void ClearJFuncGraphValueNodes(); - void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1); - void DropJFuncGraphValueNode(AnfNodePtr node); + const FuncGraphCounterMap &j_func_graphs(); + void CopyJFuncGraphs(const FuncGraphPtr &source); + void ClearJFuncGraphs(); + void AddJFuncGraph(FuncGraphPtr fg, int count = 1); + void DropJFuncGraph(FuncGraphPtr fg); // get all func graphs nested used by this func graph const FuncGraphSet &func_graphs_used_total(); // get all user value nodes of this func graph, by CNode and its input's index const CNodeIndexCounterMap &func_graph_cnodes_index(); - void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes); + void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source); void ClearFuncGraphCNodesIndex(); void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); @@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase { AnfNodeCounterMap value_nodes_; // all func graph value nodes of the function - AnfNodeCounterMap func_graph_value_nodes_; + FuncGraphCounterMap func_graphs_used_; // all free variables of the function AnfNodeCounterMap free_variables_; // all value nodes calling J in the function - AnfNodeCounterMap j_func_graph_value_nodes_; + FuncGraphCounterMap j_func_graphs_; // all user value nodes of this func graph, recording by CNode and its input's index CNodeIndexCounterMap func_graph_cnodes_index_; diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index 99d7c316e9..db52e08348 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { if (!clone_all_used_graphs_) { return; } - auto &used = func_graph->func_graph_value_nodes(); - for (auto &fg_value_node : used) { - todo_.push_back({GetValueNode(fg_value_node.first), nullptr, {}}); + auto &used = func_graph->func_graphs_used(); + for (auto &fg : used) { + todo_.push_back({fg.first, nullptr, {}}); } } diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index cfaa84a05b..a21a794fee 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { return; } AddIntoManaged(func_graph); - MS_EXCEPTION_IF_NULL(signals_); std::vector para = func_graph->parameters(); AcquireNodes(para); std::vector return_vec({func_graph->get_return()}); @@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool std::vector return_vec = {func_graph->get_return()}; todo.update(MaybeDropNodes(return_vec)); } - MS_EXCEPTION_IF_NULL(signals_); for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); all_nodes_.difference_update(fg->parameters()); @@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E } auto &users_node = node_users_[inp]; users_node.add(make_pair(node, index)); - MS_EXCEPTION_IF_NULL(signals_); AddEdge(node, index, inp); } } @@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector &nodes) { FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { AnfNodeSet nodes_ordered(nodes); FuncGraphSetPtr func_graphs_to_check = std::make_shared(); - MS_EXCEPTION_IF_NULL(signals_); - while (!nodes_ordered.empty()) { AnfNodePtr node = nodes_ordered.pop(); MS_EXCEPTION_IF_NULL(node); @@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp if (input->isa()) { fg->AddValueNode(input); if (IsValueNode(input)) { - if (fg->AddFuncGraphValueNode(input)) { - signals_->InvalidateComputer(); - } auto used = GetValueNode(input); used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->AddFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->AddJFuncGraphValueNode(input); + fg->AddJFuncGraph(used); } } } else if (fg != nullptr && fg != input->func_graph()) { @@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in if (input->isa()) { fg->DropValueNode(input); if (IsValueNode(input)) { - if (fg->DropFuncGraphValueNode(input)) { - signals_->InvalidateComputer(); - } auto used = GetValueNode(input); used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->DropFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->DropJFuncGraphValueNode(input); + fg->DropJFuncGraph(used); } } } else if (fg != nullptr && fg != input->func_graph()) { @@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in } inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { - target->CopyNodes(source->nodes()); - target->CopyValueNodes(source->value_nodes()); - target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index()); - target->CopyFreeVariables(source->free_variables()); - target->CopyFuncGraphValueNodes(source->func_graph_value_nodes()); - target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes()); + target->CopyNodes(source); + target->CopyValueNodes(source); + target->CopyFuncGraphCNodesIndex(source); + target->CopyFreeVariables(source); + target->CopyFuncGraphsUsed(source); + target->CopyJFuncGraphs(source); signals_->InvalidateComputer(); source->ClearNodes(); source->ClearValueNodes(); source->ClearFuncGraphCNodesIndex(); source->ClearFreeVariables(); - source->ClearFuncGraphValueNodes(); - source->ClearJFuncGraphValueNodes(); + source->ClearFuncGraphsUsed(); + source->ClearJFuncGraphs(); } FuncGraphTransaction FuncGraphManager::Transact() { @@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f } // Search the fv in fg's child func graph. - auto &fg_value_nodes = fg->func_graph_value_nodes(); - for (auto &fg_value_node : fg_value_nodes) { + auto &fgs = fg->func_graphs_used(); + for (auto &item : fgs) { fg->seen_ = seen_num; - auto gt = GetValueNode(fg_value_node.first); + auto gt = item.first; parents->update(SeekParents(gt, seen_num)); } (void)parents->erase(fg); @@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() { } } - auto &used = fg->func_graph_value_nodes(); + auto &used = fg->func_graphs_used(); for (auto &iter : used) { - auto p = manager->parent(GetValueNode(iter.first)); + auto p = manager->parent(iter.first); if (p == nullptr) { continue; } auto curr = fg; while (curr != p) { - (void)CounterFuncGraphCollector::Mod(curr, GetValueNode(iter.first), iter.second); + (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); } } @@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : gt->func_graph_value_nodes()) { - auto used_fg = GetValueNode(item.first); + for (auto &item : gt->func_graphs_used()) { + auto used_fg = item.first; if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); continue; @@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : gt->func_graph_value_nodes()) { - auto used_g = GetValueNode(item.first); + for (auto &item : gt->func_graphs_used()) { + auto used_g = item.first; if (used_g == fg) { return true; } @@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::listpush_back(fg); - auto &items = fg->func_graph_value_nodes(); + auto &items = fg->func_graphs_used(); for (auto iter = items.begin(); iter != items.end(); (void)iter++) { - CheckRecursiveGraphs(GetValueNode(iter->first), trace); + CheckRecursiveGraphs(iter->first, trace); } trace->pop_back(); if (!recursive_map_.count(fg)) { @@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { MS_LOG(DEBUG) << fg->ToString() << " had been checked"; return false; } - auto &j_fg_value_nodes = fg->j_func_graph_value_nodes(); - if (!j_fg_value_nodes.empty()) { + auto &j_fgs = fg->j_func_graphs(); + if (!j_fgs.empty()) { // check g1->J(fg)->g2->g cycle; - auto contains_j = - std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair iter) { - return GetValueNode(iter.first)->seen_ != seen_num; - }); - if (contains_j != j_fg_value_nodes.end()) { + auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair iter) { + return iter.first->seen_ != seen_num; + }); + if (contains_j != j_fgs.end()) { MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; return true; } @@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { fg->seen_ = seen_num; // check if func graphs used contains J(func_graph); - for (auto &item : fg->func_graph_value_nodes()) { - auto used_g = GetValueNode(item.first); + for (auto &item : fg->func_graphs_used()) { + auto used_g = item.first; if (SeekJ(used_g, seen_num)) { MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; return true; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 3b679a473f..6c3b51347f 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) { SetForwardFlag(all_nodes); } else { for (auto &func_graph : graph_set) { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size(); + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); auto return_node = func_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 25e66036a1..7b1e4d8554 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(1, iter.second.size()); } - ASSERT_EQ(1, f->func_graph_value_nodes().size()); - ASSERT_EQ(0, g->func_graph_value_nodes().size()); + ASSERT_EQ(1, f->func_graphs_used().size()); + ASSERT_EQ(0, g->func_graphs_used().size()); ASSERT_EQ(0, f->free_variables().size()); ASSERT_EQ(1, g->free_variables().size());