From d43ad79b50f1159868a5965264a29de0158e9b15 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Mon, 27 Apr 2020 15:49:26 +0800 Subject: [PATCH] Optimize the collectors of manager which listen to the graphs and nodes changes. 1. Remove the records of user graphs; 2. Remove the records of user value nodes; 3. Remove the records of user cnodes; 4. Add the records of users, and the API to access the users of graph, value node, and cnode; 5. Fix issue:User cnode record may point to its own graph, when combine the user(caller) and used one(callee); 6. Fix issue:User graphs never update itself after its first creation. --- mindspore/ccsrc/ir/func_graph.cc | 17 +-- mindspore/ccsrc/ir/func_graph.h | 8 +- mindspore/ccsrc/ir/func_graph_cloner.cc | 16 ++- mindspore/ccsrc/ir/manager.cc | 154 ++++++++++------------ mindspore/ccsrc/ir/manager.h | 135 ++++++++++--------- mindspore/ccsrc/optimizer/irpass/inline.h | 6 +- mindspore/ccsrc/vm/transform.cc | 3 +- 7 files changed, 167 insertions(+), 172 deletions(-) diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 8a58f320f1..40417a33da 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { return used; } -const FuncGraphCounterMap &FuncGraph::func_graph_users() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &users = mng->func_graph_users(); - return users[shared_from_base()]; -} - -const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { +const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { auto mng = manager_.lock(); + if (mng == nullptr) { + MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() + << " NodeInfo: " << trace::GetDebugInfo(debug_info()); + } MS_EXCEPTION_IF_NULL(mng); - auto &users = mng->func_graph_user_cnodes(); - return users[shared_from_base()]; + auto &cnode = mng->func_graph_cnodes_index(); + return cnode[shared_from_base()]; } FuncGraphPtr FuncGraph::parent() { diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 9c3752cd81..bca5759807 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -37,6 +37,7 @@ namespace mindspore { using BaseRefCounterMap = OrderedMap; using FuncGraphCounterMap = OrderedMap; using AnfNodeCounterMap = OrderedMap; +using CNodeIndexCounterMap = OrderedMap; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; @@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase { // get all func graphs nested used by this func graph const FuncGraphSet &func_graphs_used_total(); - // get all users of this func graph - const FuncGraphCounterMap &func_graph_users(); - - // get all user cnodes of this func graph - const AnfNodeCounterMap &func_graph_user_cnodes(); + // get all user value nodes of this func graph + const CNodeIndexCounterMap &func_graph_cnodes_index(); // Return the parent of this graph. FuncGraphPtr parent(); diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index c086b8d7d1..c8012276f1 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func } target_func_graph->set_return(return_node); - auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; - for (auto &value_node : value_nodes) { - CloneValueNode(value_node.first, target_func_graph); + auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; + for (auto &cnode : cnodes) { + auto parent = cnode.first->first->cast(); + auto valuenode = parent->input(cnode.first->second); + CloneValueNode(valuenode, target_func_graph); } } @@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto &user : func_graph_user->func_graph_users()) { - LiftParameters(user.first, func_graph_user, lift_params); + for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); } } @@ -395,8 +397,8 @@ void Cloner::Lift() { for (auto &func_graph_params : repl_func_graph_params_) { auto &func_graph = func_graph_params.first; auto ¶ms = func_graph_params.second; - for (auto &user : func_graph->func_graph_users()) { - LiftParameters(user.first, func_graph, params); + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); } } } diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 150e68ef4d..1ed747eefd 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -78,13 +78,16 @@ void FuncGraphManager::Reset() { node_users_ = NodeUsersMap(); signals_ = std::make_shared(); + // FuncGraph --> AnfNode nodes_ = std::make_shared(this); + + // FuncGraph --> {AnfNode, Count} valuenodes_ = std::make_shared(this); free_variables_direct_ = std::make_shared(this); - func_graph_valuenodes_ = std::make_shared(this); + func_graph_cnodes_index_ = std::make_shared(this); + + // FuncGraph --> {FuncGraph, Count} func_graphs_used_ = std::make_shared(this); - func_graph_users_ = std::make_shared(this); - func_graph_user_cnodes_ = std::make_shared(this); func_graph_child_direct_ = std::make_shared(this); func_graph_parents_direct_ = std::make_shared(this); func_graph_j_direct_ = std::make_shared(this); @@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); continue; } - MS_EXCEPTION_IF_NULL(func_graph_users_); - auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; - if (!users.empty() && !ignore_users) { + MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); + auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; + if (!users_cnode_index.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; } @@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t node->set_scope(scope); } } - for (auto &used : source->func_graphs_used()) { - (void)func_graph_users_->Inc(used.first, target, used.second); - (void)this->func_graph_users()[used.first].erase(source); - } for (auto &child : this->func_graph_child_direct()[source]) { (void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)this->func_graph_parents_direct()[child.first].erase(source); @@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { +template +bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { auto &d = count_nodes_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; @@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { +template +bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { MS_EXCEPTION_IF_NULL(func_graph); auto &d = count_nodes_map_[func_graph]; if (d.count(key) != 0) { @@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP } else { d[key] -= count; if (d[key] < 0) { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + MS_LOG(EXCEPTION) << "Count of key '" << key << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } } @@ -690,52 +693,15 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { +template +bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { return Dec(func_graph, key, -count); } else { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() - << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } -} - -bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { - auto &d = count_func_graphs_map_[func_graph]; - if (d.count(key) == 0) { - d[key] = count; - return true; - } else { - d[key] += count; - } - return false; -} - -bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { - auto &d = count_func_graphs_map_[func_graph]; - if (d.count(key) != 0) { - if (d[key] == count) { - (void)d.erase(key); - return true; - } else { - d[key] -= count; - if (d[key] < 0) { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } - } - } - return false; -} - -bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { - if (count > 0) { - return Inc(func_graph, key, count); - } else if (count < 0) { - return Dec(func_graph, key, -count); - } else { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + MS_LOG(EXCEPTION) << "Count of key '" << key << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } } @@ -754,16 +720,21 @@ void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { (void)count_nodes_map_.erase(src); } -// if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self -void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) { +void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, + EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), inp, direction); + (void)Mod(GetValueNode(inp), std::make_shared(std::make_pair(node, index)), + direction); } } -void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { +void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); + // Ignore the user graph who may own itself. + if (dst != it.first->first->func_graph()) { + (void)Inc(dst, it.first, it.second); + } } (void)count_nodes_map_.erase(src); } @@ -794,6 +765,45 @@ static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { return gn; } +bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; + if (d.count(key) == 0) { + d[key] = count; + return true; + } else { + d[key] += count; + } + return false; +} + +bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; + if (d.count(key) != 0) { + if (d[key] == count) { + (void)d.erase(key); + return true; + } else { + d[key] -= count; + if (d[key] < 0) { + MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); + } + } + } + return false; +} + +bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { + if (count > 0) { + return Inc(func_graph, key, count); + } else if (count < 0) { + return Dec(func_graph, key, -count); + } else { + MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); + } +} + void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(inp); @@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) (void)count_func_graphs_map_.erase(src); } -void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), node->func_graph(), direction); - } -} - -void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) { - // all graph use in src need to change to dst, so add dst user - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), node, direction); - } -} - -void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { if (IsValueNode(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { (void)Mod(node->func_graph(), GetValueNode(inp), direction); diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index 54c1e8a692..7f36b53205 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -100,8 +100,12 @@ struct Signals { enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; +using CNodeIndexPair = std::pair; +using CNodeIndexPairPtr = std::shared_ptr; + using FuncGraphToFuncGraphCounterMap = OrderedMap>; -using FuncGraphToAnfNodeCounterMap = OrderedMap>; +template , class CollectorEqual = std::equal_to> +using FuncGraphToAnfNodeCounterMap = OrderedMap>; // analysis base class class FuncGraphAnalysis { @@ -174,46 +178,56 @@ class NodesCollector final : public DepCollector { void OnDropNode(AnfNodePtr n) override; }; -class CounterFuncGraphCollector : public DepCollector { - public: - explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterFuncGraphCollector() override = default; - FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } - // inherit from FuncGraphAnalysis - size_t size() const override { return count_func_graphs_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } - bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - - FuncGraphToFuncGraphCounterMap count_func_graphs_map_; +struct CNodeIndexHasher { + std::size_t operator()(const CNodeIndexPairPtr pair) const { + MS_EXCEPTION_IF_NULL(pair); + MS_EXCEPTION_IF_NULL(pair->first); + return hash_combine(pair->first->hash(), std::hash()(pair->second)); + } +}; - protected: - void ExtraReset() override { count_func_graphs_map_.clear(); } +struct CNodeIndexEqual { + bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + if (lhs->first != rhs->first) { + return false; + } + if (lhs->second != rhs->second) { + return false; + } + return true; + } }; +template , class CollectorEqual = std::equal_to> class CounterAnfNodeCollector : public DepCollector { public: explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } size_t size() const override { return count_nodes_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap(); } + void OnAddFuncGraph(FuncGraphPtr fg) final { + count_nodes_map_[fg] = OrderedMap(); + } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); - FuncGraphToAnfNodeCounterMap count_nodes_map_; + FuncGraphToAnfNodeCounterMap count_nodes_map_; protected: void ExtraReset() override { count_nodes_map_.clear(); } }; -class ValueNodesCollector final : public CounterAnfNodeCollector { +class ValueNodesCollector final : public CounterAnfNodeCollector { public: explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~ValueNodesCollector() override = default; @@ -223,17 +237,19 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; }; -class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { +// Record the CNode and its input index, who points to the function graph. +class FuncGraphUsersCNodeIndexCollector final + : public CounterAnfNodeCollector { public: - explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FuncGraphValueNodesCollector() override = default; + explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} + ~FuncGraphUsersCNodeIndexCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; protected: void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; }; -class FVDirectCollector final : public CounterAnfNodeCollector { +class FVDirectCollector final : public CounterAnfNodeCollector { public: explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FVDirectCollector() override = default; @@ -243,6 +259,25 @@ class FVDirectCollector final : public CounterAnfNodeCollector { void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; }; +class CounterFuncGraphCollector : public DepCollector { + public: + explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} + ~CounterFuncGraphCollector() override = default; + FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } + // inherit from FuncGraphAnalysis + size_t size() const override { return count_func_graphs_map_.size(); } + void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } + void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } + bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + + FuncGraphToFuncGraphCounterMap count_func_graphs_map_; + + protected: + void ExtraReset() override { count_func_graphs_map_.clear(); } +}; + class FuncGraphChildDirect final : public CounterFuncGraphCollector { public: explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} @@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; }; -// graph's all user graphs: key is g, value is graphs who used g -class FuncGraphUsersCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphUsersCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// graph's all user cnodes: key is g, value is cnodes who used g -class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { - public: - explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphUserNodesCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { public: explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} @@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer { using FVTotalMap = OrderedMap>; -class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { +class FVTotalComputer final : public DepComputer, + public CounterAnfNodeCollector, + public CounterFuncGraphCollector { public: explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} @@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this { FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &free_variables_direct() const { + return free_variables_direct_->count_nodes_map_; + } - FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_cnodes_index() const { + return func_graph_cnodes_index_->count_nodes_map_; + } FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } - - FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { return func_graph_child_direct_->count_func_graphs_map_; } @@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr nodes_; std::shared_ptr valuenodes_; std::shared_ptr free_variables_direct_; - std::shared_ptr func_graph_valuenodes_; + std::shared_ptr func_graph_cnodes_index_; std::shared_ptr func_graphs_used_; - std::shared_ptr func_graph_users_; - std::shared_ptr func_graph_user_cnodes_; std::shared_ptr func_graph_child_direct_; std::shared_ptr func_graph_parents_direct_; std::shared_ptr func_graph_j_direct_; diff --git a/mindspore/ccsrc/optimizer/irpass/inline.h b/mindspore/ccsrc/optimizer/irpass/inline.h index a7b6b975bb..8ebd0f6eb7 100644 --- a/mindspore/ccsrc/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/optimizer/irpass/inline.h @@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { } bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { - auto &users = fg->func_graph_users(); + auto &cnodes = fg->func_graph_cnodes_index(); int n_use = - std::accumulate(users.begin(), users.end(), 0, - [](int sum, const std::pair &item) { return sum + item.second; }); + std::accumulate(cnodes.begin(), cnodes.end(), 0, + [](int sum, const std::pair &item) { return sum + item.second; }); return n_use == 1; } diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index b14bf54869..9147f75fb2 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { } void TraverseGraphMap( - const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, + const FuncGraphToAnfNodeCounterMap &cts, const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr);