1. Remove the records of user graphs; 2. Remove the records of user value nodes; 3. Remove the records of user cnodes; 4. Add the records of users, and the API to access the users of graph, value node, and cnode; 5. Fix issue:User cnode record may point to its own graph, when combine the user(caller) and used one(callee); 6. Fix issue:User graphs never update itself after its first creation.tags/v0.3.0-alpha
| @@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||
| return used; | |||
| } | |||
| const FuncGraphCounterMap &FuncGraph::func_graph_users() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &users = mng->func_graph_users(); | |||
| return users[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { | |||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { | |||
| auto mng = manager_.lock(); | |||
| if (mng == nullptr) { | |||
| MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &users = mng->func_graph_user_cnodes(); | |||
| return users[shared_from_base<FuncGraph>()]; | |||
| auto &cnode = mng->func_graph_cnodes_index(); | |||
| return cnode[shared_from_base<FuncGraph>()]; | |||
| } | |||
| FuncGraphPtr FuncGraph::parent() { | |||
| @@ -37,6 +37,7 @@ namespace mindspore { | |||
| using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | |||
| using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | |||
| using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; | |||
| using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; | |||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | |||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | |||
| @@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase { | |||
| // get all func graphs nested used by this func graph | |||
| const FuncGraphSet &func_graphs_used_total(); | |||
| // get all users of this func graph | |||
| const FuncGraphCounterMap &func_graph_users(); | |||
| // get all user cnodes of this func graph | |||
| const AnfNodeCounterMap &func_graph_user_cnodes(); | |||
| // get all user value nodes of this func graph | |||
| const CNodeIndexCounterMap &func_graph_cnodes_index(); | |||
| // Return the parent of this graph. | |||
| FuncGraphPtr parent(); | |||
| @@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func | |||
| } | |||
| target_func_graph->set_return(return_node); | |||
| auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||
| for (auto &value_node : value_nodes) { | |||
| CloneValueNode(value_node.first, target_func_graph); | |||
| auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; | |||
| for (auto &cnode : cnodes) { | |||
| auto parent = cnode.first->first->cast<CNodePtr>(); | |||
| auto valuenode = parent->input(cnode.first->second); | |||
| CloneValueNode(valuenode, target_func_graph); | |||
| } | |||
| } | |||
| @@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph | |||
| if (lift_params.empty()) { | |||
| return; | |||
| } | |||
| for (auto &user : func_graph_user->func_graph_users()) { | |||
| LiftParameters(user.first, func_graph_user, lift_params); | |||
| for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { | |||
| LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); | |||
| } | |||
| } | |||
| @@ -395,8 +397,8 @@ void Cloner::Lift() { | |||
| for (auto &func_graph_params : repl_func_graph_params_) { | |||
| auto &func_graph = func_graph_params.first; | |||
| auto ¶ms = func_graph_params.second; | |||
| for (auto &user : func_graph->func_graph_users()) { | |||
| LiftParameters(user.first, func_graph, params); | |||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||
| } | |||
| } | |||
| } | |||
| @@ -78,13 +78,16 @@ void FuncGraphManager::Reset() { | |||
| node_users_ = NodeUsersMap(); | |||
| signals_ = std::make_shared<Signals>(); | |||
| // FuncGraph --> AnfNode | |||
| nodes_ = std::make_shared<NodesCollector>(this); | |||
| // FuncGraph --> {AnfNode, Count} | |||
| valuenodes_ = std::make_shared<ValueNodesCollector>(this); | |||
| free_variables_direct_ = std::make_shared<FVDirectCollector>(this); | |||
| func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this); | |||
| func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this); | |||
| // FuncGraph --> {FuncGraph, Count} | |||
| func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); | |||
| func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this); | |||
| func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this); | |||
| func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); | |||
| func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); | |||
| func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); | |||
| @@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||
| MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph_users_); | |||
| auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||
| if (!users.empty() && !ignore_users) { | |||
| MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); | |||
| auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; | |||
| if (!users_cnode_index.empty() && !ignore_users) { | |||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | |||
| continue; | |||
| } | |||
| @@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||
| node->set_scope(scope); | |||
| } | |||
| } | |||
| for (auto &used : source->func_graphs_used()) { | |||
| (void)func_graph_users_->Inc(used.first, target, used.second); | |||
| (void)this->func_graph_users()[used.first].erase(source); | |||
| } | |||
| for (auto &child : this->func_graph_child_direct()[source]) { | |||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | |||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | |||
| @@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna | |||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | |||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| @@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP | |||
| return false; | |||
| } | |||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| @@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP | |||
| } else { | |||
| d[key] -= count; | |||
| if (d[key] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| @@ -690,52 +693,15 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP | |||
| return false; | |||
| } | |||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| return Dec(func_graph, key, -count); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| } else { | |||
| d[key] += count; | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| return true; | |||
| } else { | |||
| d[key] -= count; | |||
| if (d[key] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| return Dec(func_graph, key, -count); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| @@ -754,16 +720,21 @@ void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| (void)count_nodes_map_.erase(src); | |||
| } | |||
| // if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self | |||
| void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||
| void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, | |||
| EdgeProcessDirection direction) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsValueNode<FuncGraph>(inp)) { | |||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction); | |||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)), | |||
| direction); | |||
| } | |||
| } | |||
| void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| // Ignore the user graph who may own itself. | |||
| if (dst != it.first->first->func_graph()) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| } | |||
| (void)count_nodes_map_.erase(src); | |||
| } | |||
| @@ -794,6 +765,45 @@ static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { | |||
| return gn; | |||
| } | |||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| } else { | |||
| d[key] += count; | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| return true; | |||
| } else { | |||
| d[key] -= count; | |||
| if (d[key] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| return Dec(func_graph, key, -count); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(inp); | |||
| @@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) | |||
| (void)count_func_graphs_map_.erase(src); | |||
| } | |||
| void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsValueNode<FuncGraph>(inp)) { | |||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction); | |||
| } | |||
| } | |||
| void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) { | |||
| // all graph use in src need to change to dst, so add dst user | |||
| (void)count_func_graphs_map_.erase(src); | |||
| } | |||
| void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsValueNode<FuncGraph>(inp)) { | |||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction); | |||
| } | |||
| } | |||
| void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_nodes_map_.erase(src); | |||
| } | |||
| void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||
| if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); | |||
| @@ -100,8 +100,12 @@ struct Signals { | |||
| enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; | |||
| using CNodeIndexPair = std::pair<AnfNodePtr, int>; | |||
| using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>; | |||
| using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>; | |||
| using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int>>; | |||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||
| using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>; | |||
| // analysis base class | |||
| class FuncGraphAnalysis { | |||
| @@ -174,46 +178,56 @@ class NodesCollector final : public DepCollector { | |||
| void OnDropNode(AnfNodePtr n) override; | |||
| }; | |||
| class CounterFuncGraphCollector : public DepCollector { | |||
| public: | |||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterFuncGraphCollector() override = default; | |||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||
| // inherit from FuncGraphAnalysis | |||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||
| struct CNodeIndexHasher { | |||
| std::size_t operator()(const CNodeIndexPairPtr pair) const { | |||
| MS_EXCEPTION_IF_NULL(pair); | |||
| MS_EXCEPTION_IF_NULL(pair->first); | |||
| return hash_combine(pair->first->hash(), std::hash<int>()(pair->second)); | |||
| } | |||
| }; | |||
| protected: | |||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | |||
| struct CNodeIndexEqual { | |||
| bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { | |||
| if (lhs == nullptr || rhs == nullptr) { | |||
| return false; | |||
| } | |||
| if (lhs == rhs) { | |||
| return true; | |||
| } | |||
| if (lhs->first != rhs->first) { | |||
| return false; | |||
| } | |||
| if (lhs->second != rhs->second) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| }; | |||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||
| class CounterAnfNodeCollector : public DepCollector { | |||
| public: | |||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterAnfNodeCollector() override = default; | |||
| FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; } | |||
| size_t size() const override { return count_nodes_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { | |||
| count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>(); | |||
| } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| FuncGraphToAnfNodeCounterMap count_nodes_map_; | |||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_; | |||
| protected: | |||
| void ExtraReset() override { count_nodes_map_.clear(); } | |||
| }; | |||
| class ValueNodesCollector final : public CounterAnfNodeCollector { | |||
| class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||
| public: | |||
| explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~ValueNodesCollector() override = default; | |||
| @@ -223,17 +237,19 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | |||
| // Record the CNode and its input index, who points to the function graph. | |||
| class FuncGraphUsersCNodeIndexCollector final | |||
| : public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> { | |||
| public: | |||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~FuncGraphValueNodesCollector() override = default; | |||
| explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~FuncGraphUsersCNodeIndexCollector() override = default; | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| protected: | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| class FVDirectCollector final : public CounterAnfNodeCollector { | |||
| class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||
| public: | |||
| explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~FVDirectCollector() override = default; | |||
| @@ -243,6 +259,25 @@ class FVDirectCollector final : public CounterAnfNodeCollector { | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| class CounterFuncGraphCollector : public DepCollector { | |||
| public: | |||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterFuncGraphCollector() override = default; | |||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||
| // inherit from FuncGraphAnalysis | |||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||
| protected: | |||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | |||
| }; | |||
| class FuncGraphChildDirect final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| @@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| // graph's all user graphs: key is g, value is graphs who used g | |||
| class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphUsersCollector() override = default; | |||
| protected: | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| // graph's all user cnodes: key is g, value is cnodes who used g | |||
| class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | |||
| public: | |||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphUserNodesCollector() override = default; | |||
| protected: | |||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||
| }; | |||
| class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| @@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer { | |||
| using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; | |||
| class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { | |||
| class FVTotalComputer final : public DepComputer, | |||
| public CounterAnfNodeCollector<AnfNodePtr>, | |||
| public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FVTotalComputer(const FuncGraphManager *m) | |||
| : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | |||
| @@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | |||
| FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const { | |||
| return free_variables_direct_->count_nodes_map_; | |||
| } | |||
| FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const { | |||
| return func_graph_cnodes_index_->count_nodes_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||
| FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | |||
| return func_graph_child_direct_->count_func_graphs_map_; | |||
| } | |||
| @@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| std::shared_ptr<NodesCollector> nodes_; | |||
| std::shared_ptr<ValueNodesCollector> valuenodes_; | |||
| std::shared_ptr<FVDirectCollector> free_variables_direct_; | |||
| std::shared_ptr<FuncGraphValueNodesCollector> func_graph_valuenodes_; | |||
| std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_; | |||
| std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; | |||
| std::shared_ptr<FuncGraphUsersCollector> func_graph_users_; | |||
| std::shared_ptr<FuncGraphUserNodesCollector> func_graph_user_cnodes_; | |||
| std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; | |||
| std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; | |||
| std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; | |||
| @@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| } | |||
| bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| auto &users = fg->func_graph_users(); | |||
| auto &cnodes = fg->func_graph_cnodes_index(); | |||
| int n_use = | |||
| std::accumulate(users.begin(), users.end(), 0, | |||
| [](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; }); | |||
| std::accumulate(cnodes.begin(), cnodes.end(), 0, | |||
| [](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; }); | |||
| return n_use == 1; | |||
| } | |||
| @@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { | |||
| } | |||
| void TraverseGraphMap( | |||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, | |||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | |||
| const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, | |||
| const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| MS_EXCEPTION_IF_NULL(tr); | |||