1. Remove the records of user graphs; 2. Remove the records of user value nodes; 3. Remove the records of user cnodes; 4. Add the records of users, and the API to access the users of graph, value node, and cnode; 5. Fix issue:User cnode record may point to its own graph, when combine the user(caller) and used one(callee); 6. Fix issue:User graphs never update itself after its first creation.tags/v0.3.0-alpha
| @@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||||
| return used; | return used; | ||||
| } | } | ||||
| const FuncGraphCounterMap &FuncGraph::func_graph_users() { | |||||
| auto mng = manager_.lock(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &users = mng->func_graph_users(); | |||||
| return users[shared_from_base<FuncGraph>()]; | |||||
| } | |||||
| const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { | |||||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| if (mng == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() | |||||
| << " NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto &users = mng->func_graph_user_cnodes(); | |||||
| return users[shared_from_base<FuncGraph>()]; | |||||
| auto &cnode = mng->func_graph_cnodes_index(); | |||||
| return cnode[shared_from_base<FuncGraph>()]; | |||||
| } | } | ||||
| FuncGraphPtr FuncGraph::parent() { | FuncGraphPtr FuncGraph::parent() { | ||||
| @@ -37,6 +37,7 @@ namespace mindspore { | |||||
| using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | ||||
| using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | ||||
| using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; | using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; | ||||
| using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; | |||||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | ||||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | ||||
| @@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase { | |||||
| // get all func graphs nested used by this func graph | // get all func graphs nested used by this func graph | ||||
| const FuncGraphSet &func_graphs_used_total(); | const FuncGraphSet &func_graphs_used_total(); | ||||
| // get all users of this func graph | |||||
| const FuncGraphCounterMap &func_graph_users(); | |||||
| // get all user cnodes of this func graph | |||||
| const AnfNodeCounterMap &func_graph_user_cnodes(); | |||||
| // get all user value nodes of this func graph | |||||
| const CNodeIndexCounterMap &func_graph_cnodes_index(); | |||||
| // Return the parent of this graph. | // Return the parent of this graph. | ||||
| FuncGraphPtr parent(); | FuncGraphPtr parent(); | ||||
| @@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func | |||||
| } | } | ||||
| target_func_graph->set_return(return_node); | target_func_graph->set_return(return_node); | ||||
| auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||||
| for (auto &value_node : value_nodes) { | |||||
| CloneValueNode(value_node.first, target_func_graph); | |||||
| auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; | |||||
| for (auto &cnode : cnodes) { | |||||
| auto parent = cnode.first->first->cast<CNodePtr>(); | |||||
| auto valuenode = parent->input(cnode.first->second); | |||||
| CloneValueNode(valuenode, target_func_graph); | |||||
| } | } | ||||
| } | } | ||||
| @@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph | |||||
| if (lift_params.empty()) { | if (lift_params.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto &user : func_graph_user->func_graph_users()) { | |||||
| LiftParameters(user.first, func_graph_user, lift_params); | |||||
| for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { | |||||
| LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); | |||||
| } | } | ||||
| } | } | ||||
| @@ -395,8 +397,8 @@ void Cloner::Lift() { | |||||
| for (auto &func_graph_params : repl_func_graph_params_) { | for (auto &func_graph_params : repl_func_graph_params_) { | ||||
| auto &func_graph = func_graph_params.first; | auto &func_graph = func_graph_params.first; | ||||
| auto ¶ms = func_graph_params.second; | auto ¶ms = func_graph_params.second; | ||||
| for (auto &user : func_graph->func_graph_users()) { | |||||
| LiftParameters(user.first, func_graph, params); | |||||
| for (auto &cnode : func_graph->func_graph_cnodes_index()) { | |||||
| LiftParameters(cnode.first->first->func_graph(), func_graph, params); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -78,13 +78,16 @@ void FuncGraphManager::Reset() { | |||||
| node_users_ = NodeUsersMap(); | node_users_ = NodeUsersMap(); | ||||
| signals_ = std::make_shared<Signals>(); | signals_ = std::make_shared<Signals>(); | ||||
| // FuncGraph --> AnfNode | |||||
| nodes_ = std::make_shared<NodesCollector>(this); | nodes_ = std::make_shared<NodesCollector>(this); | ||||
| // FuncGraph --> {AnfNode, Count} | |||||
| valuenodes_ = std::make_shared<ValueNodesCollector>(this); | valuenodes_ = std::make_shared<ValueNodesCollector>(this); | ||||
| free_variables_direct_ = std::make_shared<FVDirectCollector>(this); | free_variables_direct_ = std::make_shared<FVDirectCollector>(this); | ||||
| func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this); | |||||
| func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this); | |||||
| // FuncGraph --> {FuncGraph, Count} | |||||
| func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); | func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); | ||||
| func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this); | |||||
| func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this); | |||||
| func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); | func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); | ||||
| func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); | func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); | ||||
| func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); | func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); | ||||
| @@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||||
| MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func_graph_users_); | |||||
| auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||||
| if (!users.empty() && !ignore_users) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); | |||||
| auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; | |||||
| if (!users_cnode_index.empty() && !ignore_users) { | |||||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||||
| node->set_scope(scope); | node->set_scope(scope); | ||||
| } | } | ||||
| } | } | ||||
| for (auto &used : source->func_graphs_used()) { | |||||
| (void)func_graph_users_->Inc(used.first, target, used.second); | |||||
| (void)this->func_graph_users()[used.first].erase(source); | |||||
| } | |||||
| for (auto &child : this->func_graph_child_direct()[source]) { | for (auto &child : this->func_graph_child_direct()[source]) { | ||||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | ||||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | (void)this->func_graph_parents_direct()[child.first].erase(source); | ||||
| @@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna | |||||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | ||||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph, | |||||
| const ValueT &key, int count) { | |||||
| auto &d = count_nodes_map_[func_graph]; | auto &d = count_nodes_map_[func_graph]; | ||||
| if (d.count(key) == 0) { | if (d.count(key) == 0) { | ||||
| d[key] = count; | d[key] = count; | ||||
| @@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph, | |||||
| const ValueT &key, int count) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto &d = count_nodes_map_[func_graph]; | auto &d = count_nodes_map_[func_graph]; | ||||
| if (d.count(key) != 0) { | if (d.count(key) != 0) { | ||||
| @@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP | |||||
| } else { | } else { | ||||
| d[key] -= count; | d[key] -= count; | ||||
| if (d[key] < 0) { | if (d[key] < 0) { | ||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -690,52 +693,15 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { | |||||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph, | |||||
| const ValueT &key, int count) { | |||||
| if (count > 0) { | if (count > 0) { | ||||
| return Inc(func_graph, key, count); | return Inc(func_graph, key, count); | ||||
| } else if (count < 0) { | } else if (count < 0) { | ||||
| return Dec(func_graph, key, -count); | return Dec(func_graph, key, -count); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||||
| } | |||||
| } | |||||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) == 0) { | |||||
| d[key] = count; | |||||
| return true; | |||||
| } else { | |||||
| d[key] += count; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) != 0) { | |||||
| if (d[key] == count) { | |||||
| (void)d.erase(key); | |||||
| return true; | |||||
| } else { | |||||
| d[key] -= count; | |||||
| if (d[key] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||||
| if (count > 0) { | |||||
| return Inc(func_graph, key, count); | |||||
| } else if (count < 0) { | |||||
| return Dec(func_graph, key, -count); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -754,16 +720,21 @@ void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| } | } | ||||
| // if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self | |||||
| void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, | |||||
| EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsValueNode<FuncGraph>(inp)) { | if (IsValueNode<FuncGraph>(inp)) { | ||||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction); | |||||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)), | |||||
| direction); | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_nodes_map_[src]) { | for (auto &it : count_nodes_map_[src]) { | ||||
| (void)Inc(dst, it.first, it.second); | |||||
| // Ignore the user graph who may own itself. | |||||
| if (dst != it.first->first->func_graph()) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| } | } | ||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| } | } | ||||
| @@ -794,6 +765,45 @@ static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { | |||||
| return gn; | return gn; | ||||
| } | } | ||||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) == 0) { | |||||
| d[key] = count; | |||||
| return true; | |||||
| } else { | |||||
| d[key] += count; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) != 0) { | |||||
| if (d[key] == count) { | |||||
| (void)d.erase(key); | |||||
| return true; | |||||
| } else { | |||||
| d[key] -= count; | |||||
| if (d[key] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||||
| if (count > 0) { | |||||
| return Inc(func_graph, key, count); | |||||
| } else if (count < 0) { | |||||
| return Dec(func_graph, key, -count); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||||
| } | |||||
| } | |||||
| void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(inp); | MS_EXCEPTION_IF_NULL(inp); | ||||
| @@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) | |||||
| (void)count_func_graphs_map_.erase(src); | (void)count_func_graphs_map_.erase(src); | ||||
| } | } | ||||
| void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsValueNode<FuncGraph>(inp)) { | |||||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction); | |||||
| } | |||||
| } | |||||
| void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) { | |||||
| // all graph use in src need to change to dst, so add dst user | |||||
| (void)count_func_graphs_map_.erase(src); | |||||
| } | |||||
| void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsValueNode<FuncGraph>(inp)) { | |||||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction); | |||||
| } | |||||
| } | |||||
| void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| (void)count_nodes_map_.erase(src); | |||||
| } | |||||
| void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | ||||
| if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { | if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { | ||||
| (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); | (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); | ||||
| @@ -100,8 +100,12 @@ struct Signals { | |||||
| enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; | enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; | ||||
| using CNodeIndexPair = std::pair<AnfNodePtr, int>; | |||||
| using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>; | |||||
| using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>; | using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>; | ||||
| using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int>>; | |||||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||||
| using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>; | |||||
| // analysis base class | // analysis base class | ||||
| class FuncGraphAnalysis { | class FuncGraphAnalysis { | ||||
| @@ -174,46 +178,56 @@ class NodesCollector final : public DepCollector { | |||||
| void OnDropNode(AnfNodePtr n) override; | void OnDropNode(AnfNodePtr n) override; | ||||
| }; | }; | ||||
| class CounterFuncGraphCollector : public DepCollector { | |||||
| public: | |||||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterFuncGraphCollector() override = default; | |||||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||||
| // inherit from FuncGraphAnalysis | |||||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||||
| struct CNodeIndexHasher { | |||||
| std::size_t operator()(const CNodeIndexPairPtr pair) const { | |||||
| MS_EXCEPTION_IF_NULL(pair); | |||||
| MS_EXCEPTION_IF_NULL(pair->first); | |||||
| return hash_combine(pair->first->hash(), std::hash<int>()(pair->second)); | |||||
| } | |||||
| }; | |||||
| protected: | |||||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | |||||
| struct CNodeIndexEqual { | |||||
| bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { | |||||
| if (lhs == nullptr || rhs == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (lhs == rhs) { | |||||
| return true; | |||||
| } | |||||
| if (lhs->first != rhs->first) { | |||||
| return false; | |||||
| } | |||||
| if (lhs->second != rhs->second) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| }; | }; | ||||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||||
| class CounterAnfNodeCollector : public DepCollector { | class CounterAnfNodeCollector : public DepCollector { | ||||
| public: | public: | ||||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | ||||
| ~CounterAnfNodeCollector() override = default; | ~CounterAnfNodeCollector() override = default; | ||||
| FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; } | |||||
| size_t size() const override { return count_nodes_map_.size(); } | size_t size() const override { return count_nodes_map_.size(); } | ||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { | |||||
| count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>(); | |||||
| } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | ||||
| bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| FuncGraphToAnfNodeCounterMap count_nodes_map_; | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_; | |||||
| protected: | protected: | ||||
| void ExtraReset() override { count_nodes_map_.clear(); } | void ExtraReset() override { count_nodes_map_.clear(); } | ||||
| }; | }; | ||||
| class ValueNodesCollector final : public CounterAnfNodeCollector { | |||||
| class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||||
| public: | public: | ||||
| explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | ||||
| ~ValueNodesCollector() override = default; | ~ValueNodesCollector() override = default; | ||||
| @@ -223,17 +237,19 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | ||||
| }; | }; | ||||
| class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | |||||
| // Record the CNode and its input index, who points to the function graph. | |||||
| class FuncGraphUsersCNodeIndexCollector final | |||||
| : public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> { | |||||
| public: | public: | ||||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FuncGraphValueNodesCollector() override = default; | |||||
| explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FuncGraphUsersCNodeIndexCollector() override = default; | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| protected: | protected: | ||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | ||||
| }; | }; | ||||
| class FVDirectCollector final : public CounterAnfNodeCollector { | |||||
| class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||||
| public: | public: | ||||
| explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | ||||
| ~FVDirectCollector() override = default; | ~FVDirectCollector() override = default; | ||||
| @@ -243,6 +259,25 @@ class FVDirectCollector final : public CounterAnfNodeCollector { | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | ||||
| }; | }; | ||||
| class CounterFuncGraphCollector : public DepCollector { | |||||
| public: | |||||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterFuncGraphCollector() override = default; | |||||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||||
| // inherit from FuncGraphAnalysis | |||||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||||
| protected: | |||||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | |||||
| }; | |||||
| class FuncGraphChildDirect final : public CounterFuncGraphCollector { | class FuncGraphChildDirect final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | ||||
| @@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | ||||
| }; | }; | ||||
| // graph's all user graphs: key is g, value is graphs who used g | |||||
| class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | |||||
| public: | |||||
| explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| ~FuncGraphUsersCollector() override = default; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| // graph's all user cnodes: key is g, value is cnodes who used g | |||||
| class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | |||||
| public: | |||||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| ~FuncGraphUserNodesCollector() override = default; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | ||||
| @@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer { | |||||
| using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; | using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; | ||||
| class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { | |||||
| class FVTotalComputer final : public DepComputer, | |||||
| public CounterAnfNodeCollector<AnfNodePtr>, | |||||
| public CounterFuncGraphCollector { | |||||
| public: | public: | ||||
| explicit FVTotalComputer(const FuncGraphManager *m) | explicit FVTotalComputer(const FuncGraphManager *m) | ||||
| : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | ||||
| @@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | ||||
| FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const { | |||||
| return free_variables_direct_->count_nodes_map_; | |||||
| } | |||||
| FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const { | |||||
| return func_graph_cnodes_index_->count_nodes_map_; | |||||
| } | |||||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | ||||
| FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | ||||
| return func_graph_child_direct_->count_func_graphs_map_; | return func_graph_child_direct_->count_func_graphs_map_; | ||||
| } | } | ||||
| @@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| std::shared_ptr<NodesCollector> nodes_; | std::shared_ptr<NodesCollector> nodes_; | ||||
| std::shared_ptr<ValueNodesCollector> valuenodes_; | std::shared_ptr<ValueNodesCollector> valuenodes_; | ||||
| std::shared_ptr<FVDirectCollector> free_variables_direct_; | std::shared_ptr<FVDirectCollector> free_variables_direct_; | ||||
| std::shared_ptr<FuncGraphValueNodesCollector> func_graph_valuenodes_; | |||||
| std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_; | |||||
| std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; | std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; | ||||
| std::shared_ptr<FuncGraphUsersCollector> func_graph_users_; | |||||
| std::shared_ptr<FuncGraphUserNodesCollector> func_graph_user_cnodes_; | |||||
| std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; | std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; | ||||
| std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; | std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; | ||||
| std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; | std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; | ||||
| @@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { | |||||
| } | } | ||||
| bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | ||||
| auto &users = fg->func_graph_users(); | |||||
| auto &cnodes = fg->func_graph_cnodes_index(); | |||||
| int n_use = | int n_use = | ||||
| std::accumulate(users.begin(), users.end(), 0, | |||||
| [](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; }); | |||||
| std::accumulate(cnodes.begin(), cnodes.end(), 0, | |||||
| [](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; }); | |||||
| return n_use == 1; | return n_use == 1; | ||||
| } | } | ||||
| @@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { | |||||
| } | } | ||||
| void TraverseGraphMap( | void TraverseGraphMap( | ||||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, | |||||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | |||||
| const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, | |||||
| const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | ||||
| MS_EXCEPTION_IF_NULL(manager_ptr); | MS_EXCEPTION_IF_NULL(manager_ptr); | ||||
| MS_EXCEPTION_IF_NULL(tr); | MS_EXCEPTION_IF_NULL(tr); | ||||