Merge pull request !1359 from ZhangQinghua/mastertags/v0.3.0-alpha
| @@ -29,6 +29,7 @@ | |||||
| #include "utils/visible.h" | #include "utils/visible.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ordered_set.h" | #include "utils/ordered_set.h" | ||||
| #include "utils/ordered_map.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| template <typename T> | template <typename T> | ||||
| @@ -47,6 +47,7 @@ FuncGraph::FuncGraph() | |||||
| : flags_(), | : flags_(), | ||||
| transforms_(), | transforms_(), | ||||
| parameter_default_value_(), | parameter_default_value_(), | ||||
| seen_(0), | |||||
| parameters_(), | parameters_(), | ||||
| has_vararg_(false), | has_vararg_(false), | ||||
| has_kwarg_(false), | has_kwarg_(false), | ||||
| @@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() { | |||||
| return this->debug_info_; | return this->debug_info_; | ||||
| } | } | ||||
| const AnfNodeSet &FuncGraph::nodes() { | |||||
| auto mng = manager_.lock(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &nodes = mng->nodes(); | |||||
| return nodes[shared_from_base<FuncGraph>()]; | |||||
| const AnfNodeSet &FuncGraph::nodes() { return nodes_; } | |||||
| void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } | |||||
| void FuncGraph::ClearNodes() { nodes_.clear(); } | |||||
| void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } | |||||
| void FuncGraph::DropNode(AnfNodePtr node) { | |||||
| nodes_.erase(node); | |||||
| auto graph = node->func_graph(); | |||||
| // Remove the node from order list. | |||||
| if (graph) { | |||||
| graph->EraseUnusedNodeInOrder(node); | |||||
| } | |||||
| } | } | ||||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { | |||||
| auto mng = manager_.lock(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &cts = mng->valuenodes(); | |||||
| return cts[shared_from_base<FuncGraph>()]; | |||||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } | |||||
| void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { | |||||
| auto &others = source->value_nodes(); | |||||
| for (auto it = others.begin(); it != others.end(); it++) { | |||||
| AddValueNode(it->first, it->second); | |||||
| } | |||||
| } | } | ||||
| const AnfNodeCounterMap &FuncGraph::free_variables_direct() { | |||||
| auto mng = manager_.lock(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &fv_direct = mng->free_variables_direct(); | |||||
| return fv_direct[shared_from_base<FuncGraph>()]; | |||||
| void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } | |||||
| void FuncGraph::AddValueNode(AnfNodePtr node, int count) { | |||||
| if (value_nodes_.count(node) == 0) { | |||||
| value_nodes_[node] = count; | |||||
| } else { | |||||
| value_nodes_[node] += count; | |||||
| } | |||||
| } | |||||
| void FuncGraph::DropValueNode(AnfNodePtr node) { | |||||
| if (value_nodes_.count(node) != 0) { | |||||
| if (value_nodes_[node] == 1) { | |||||
| (void)value_nodes_.erase(node); | |||||
| } else { | |||||
| value_nodes_[node]--; | |||||
| if (value_nodes_[node] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of ValueNode '" << node | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } | |||||
| void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { | |||||
| auto &others = source->free_variables(); | |||||
| for (auto it = others.begin(); it != others.end(); it++) { | |||||
| if (it->first->func_graph().get() != this) { | |||||
| (void)AddFreeVariable(it->first, it->second); | |||||
| } | |||||
| } | |||||
| } | |||||
| void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } | |||||
| bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { | |||||
| if (free_variables_.count(node) == 0) { | |||||
| free_variables_[node] = count; | |||||
| return true; | |||||
| } else { | |||||
| free_variables_[node] += count; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool FuncGraph::DropFreeVariable(AnfNodePtr node) { | |||||
| if (free_variables_.count(node) != 0) { | |||||
| if (free_variables_[node] == 1) { | |||||
| (void)free_variables_.erase(node); | |||||
| return true; | |||||
| } else { | |||||
| free_variables_[node]--; | |||||
| if (free_variables_[node] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of free variable '" << node | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| const BaseRefCounterMap &FuncGraph::free_variables_total() { | const BaseRefCounterMap &FuncGraph::free_variables_total() { | ||||
| @@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | |||||
| return func_graphs; | return func_graphs; | ||||
| } | } | ||||
| const FuncGraphCounterMap &FuncGraph::func_graphs_used() { | |||||
| auto mng = manager_.lock(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &used = mng->func_graphs_used(); | |||||
| return used[shared_from_base<FuncGraph>()]; | |||||
| const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } | |||||
| void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { | |||||
| auto &others = source->func_graphs_used(); | |||||
| for (auto it = others.begin(); it != others.end(); it++) { | |||||
| (void)AddFuncGraphUsed(it->first, it->second); | |||||
| } | |||||
| func_graphs_used_.erase(source); | |||||
| } | |||||
| void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } | |||||
| bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { | |||||
| if (func_graphs_used_.count(fg) == 0) { | |||||
| func_graphs_used_[fg] = count; | |||||
| return true; | |||||
| } else { | |||||
| func_graphs_used_[fg] += count; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { | |||||
| if (func_graphs_used_.count(fg) != 0) { | |||||
| if (func_graphs_used_[fg] == 1) { | |||||
| (void)func_graphs_used_.erase(fg); | |||||
| return true; | |||||
| } else { | |||||
| func_graphs_used_[fg]--; | |||||
| if (func_graphs_used_[fg] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| const FuncGraphSet &FuncGraph::func_graphs_used_total() { | const FuncGraphSet &FuncGraph::func_graphs_used_total() { | ||||
| @@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||||
| return used; | return used; | ||||
| } | } | ||||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { | |||||
| auto mng = manager_.lock(); | |||||
| if (mng == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() | |||||
| << " NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } | |||||
| void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { | |||||
| auto &others = source->func_graph_cnodes_index(); | |||||
| for (auto it = others.begin(); it != others.end(); it++) { | |||||
| // Ignore the user graph who may own itself. | |||||
| auto fg = it->first->first->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| if (fg.get() != this) { | |||||
| AddFuncGraphCNodeIndex(it->first, it->second); | |||||
| } | |||||
| } | |||||
| } | |||||
| void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } | |||||
| void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { | |||||
| if (func_graph_cnodes_index_.count(pair) == 0) { | |||||
| func_graph_cnodes_index_[pair] = count; | |||||
| } else { | |||||
| func_graph_cnodes_index_[pair] += count; | |||||
| } | |||||
| } | |||||
| void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { | |||||
| if (func_graph_cnodes_index_.count(pair) != 0) { | |||||
| if (func_graph_cnodes_index_[pair] == 1) { | |||||
| (void)func_graph_cnodes_index_.erase(pair); | |||||
| } else { | |||||
| func_graph_cnodes_index_[pair]--; | |||||
| if (func_graph_cnodes_index_[pair] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } | |||||
| void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { | |||||
| auto &others = source->j_func_graphs(); | |||||
| for (auto it = others.begin(); it != others.end(); it++) { | |||||
| AddJFuncGraph(it->first, it->second); | |||||
| } | |||||
| } | |||||
| void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } | |||||
| void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { | |||||
| if (j_func_graphs_.count(fg) == 0) { | |||||
| j_func_graphs_[fg] = count; | |||||
| } else { | |||||
| j_func_graphs_[fg] += count; | |||||
| } | |||||
| } | |||||
| void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { | |||||
| if (j_func_graphs_.count(fg) != 0) { | |||||
| if (j_func_graphs_[fg] == 1) { | |||||
| (void)j_func_graphs_.erase(fg); | |||||
| } else { | |||||
| j_func_graphs_[fg]--; | |||||
| if (j_func_graphs_[fg] < 0) { | |||||
| MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg | |||||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &cnode = mng->func_graph_cnodes_index(); | |||||
| return cnode[shared_from_base<FuncGraph>()]; | |||||
| } | } | ||||
| FuncGraphPtr FuncGraph::parent() { | FuncGraphPtr FuncGraph::parent() { | ||||
| @@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() { | |||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| if (mng) { | if (mng) { | ||||
| auto nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||||
| auto &all_nodes = nodes(); | |||||
| // Erase unused cnode. | // Erase unused cnode. | ||||
| for (auto it = order_.begin(); it != order_.end();) { | for (auto it = order_.begin(); it != order_.end();) { | ||||
| if (nodes.count(*it)) { | |||||
| if (all_nodes.count(*it)) { | |||||
| (void)it++; | (void)it++; | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; | MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; | ||||
| @@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() { | |||||
| } | } | ||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| if (mng != nullptr) { | if (mng != nullptr) { | ||||
| const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||||
| if (nodes.size() != (order_.size() + parameters_.size())) { | |||||
| const auto &all_nodes = nodes(); | |||||
| if (all_nodes.size() != (order_.size() + parameters_.size())) { | |||||
| DumpCNodeList(); | DumpCNodeList(); | ||||
| MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | ||||
| << nodes.size() - parameters_.size() << "."; | |||||
| << all_nodes.size() - parameters_.size() << "."; | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Check order okay."; | MS_LOG(DEBUG) << "Check order okay."; | ||||
| @@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) { | |||||
| } | } | ||||
| } | } | ||||
| size_t NewFgSeenGeneration() { | |||||
| static size_t fg_seen_generation = 0; | |||||
| return ++fg_seen_generation; | |||||
| } | |||||
| const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph"); | const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph"); | ||||
| const char kFuncGraphFlagUndetermined[] = "Undeterminate"; | const char kFuncGraphFlagUndetermined[] = "Undeterminate"; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <functional> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| @@ -36,8 +37,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | ||||
| using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | ||||
| using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; | |||||
| using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; | |||||
| template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>> | |||||
| using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>; | |||||
| using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>; | |||||
| using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>; | |||||
| using FuncGraphMap = OrderedMap<FuncGraphPtr, int>; | |||||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | ||||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | ||||
| @@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase { | |||||
| // get all nodes belonging to this func graph | // get all nodes belonging to this func graph | ||||
| const AnfNodeSet &nodes(); | const AnfNodeSet &nodes(); | ||||
| void CopyNodes(const FuncGraphPtr &source); | |||||
| void ClearNodes(); | |||||
| void AddNode(AnfNodePtr node); | |||||
| void DropNode(AnfNodePtr node); | |||||
| // get all value_nodes belonging to this func graph | // get all value_nodes belonging to this func graph | ||||
| const AnfNodeCounterMap &value_nodes(); | const AnfNodeCounterMap &value_nodes(); | ||||
| // get all vars directly pointed to in this func graph | |||||
| const AnfNodeCounterMap &free_variables_direct(); | |||||
| void CopyValueNodes(const FuncGraphPtr &source); | |||||
| void ClearValueNodes(); | |||||
| void AddValueNode(AnfNodePtr node, int count = 1); | |||||
| void DropValueNode(AnfNodePtr node); | |||||
| // get all free vars directly used in this func graph | |||||
| const AnfNodeCounterMap &free_variables(); | |||||
| void CopyFreeVariables(const FuncGraphPtr &source); | |||||
| void ClearFreeVariables(); | |||||
| bool AddFreeVariable(AnfNodePtr node, int count = 1); | |||||
| bool DropFreeVariable(AnfNodePtr node); | |||||
| // get all vars required by this func graph | // get all vars required by this func graph | ||||
| const BaseRefCounterMap &free_variables_total(); | const BaseRefCounterMap &free_variables_total(); | ||||
| @@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase { | |||||
| // get all vars that are func graphs | // get all vars that are func graphs | ||||
| std::vector<FuncGraphPtr> free_variables_func_graphs(); | std::vector<FuncGraphPtr> free_variables_func_graphs(); | ||||
| // get all func graphs directly used by this func graph | |||||
| // get all value nodes of func graph directly used by this func graph | |||||
| const FuncGraphCounterMap &func_graphs_used(); | const FuncGraphCounterMap &func_graphs_used(); | ||||
| void CopyFuncGraphsUsed(const FuncGraphPtr &source); | |||||
| void ClearFuncGraphsUsed(); | |||||
| bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); | |||||
| bool DropFuncGraphUsed(FuncGraphPtr fg); | |||||
| // get all value nodes of J func graph directly used by this func graph | |||||
| const FuncGraphCounterMap &j_func_graphs(); | |||||
| void CopyJFuncGraphs(const FuncGraphPtr &source); | |||||
| void ClearJFuncGraphs(); | |||||
| void AddJFuncGraph(FuncGraphPtr fg, int count = 1); | |||||
| void DropJFuncGraph(FuncGraphPtr fg); | |||||
| // get all func graphs nested used by this func graph | // get all func graphs nested used by this func graph | ||||
| const FuncGraphSet &func_graphs_used_total(); | const FuncGraphSet &func_graphs_used_total(); | ||||
| // get all user value nodes of this func graph | |||||
| // get all user value nodes of this func graph, by CNode and its input's index | |||||
| const CNodeIndexCounterMap &func_graph_cnodes_index(); | const CNodeIndexCounterMap &func_graph_cnodes_index(); | ||||
| void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source); | |||||
| void ClearFuncGraphCNodesIndex(); | |||||
| void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); | |||||
| void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); | |||||
| // Return the parent of this graph. | // Return the parent of this graph. | ||||
| FuncGraphPtr parent(); | FuncGraphPtr parent(); | ||||
| @@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase { | |||||
| // parameter default value | // parameter default value | ||||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | std::map<std::string, AnfNodePtr> parameter_default_value_; | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; | std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; | ||||
| size_t seen_; | |||||
| std::list<CNodePtr> GetOrderedCnodes(); | std::list<CNodePtr> GetOrderedCnodes(); | ||||
| void EraseUnusedNodeInOrder(const AnfNodePtr &n); | void EraseUnusedNodeInOrder(const AnfNodePtr &n); | ||||
| @@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase { | |||||
| // graph is manipulated by manager and others | // graph is manipulated by manager and others | ||||
| friend FuncGraphManager; | friend FuncGraphManager; | ||||
| // all nodes of the function | |||||
| AnfNodeSet nodes_; | |||||
| // all value nodes of the function | |||||
| AnfNodeCounterMap value_nodes_; | |||||
| // all func graph value nodes of the function | |||||
| FuncGraphCounterMap func_graphs_used_; | |||||
| // all free variables of the function | |||||
| AnfNodeCounterMap free_variables_; | |||||
| // all value nodes calling J in the function | |||||
| FuncGraphCounterMap j_func_graphs_; | |||||
| // all user value nodes of this func graph, recording by CNode and its input's index | |||||
| CNodeIndexCounterMap func_graph_cnodes_index_; | |||||
| // parameters of this function | // parameters of this function | ||||
| std::vector<AnfNodePtr> parameters_; | std::vector<AnfNodePtr> parameters_; | ||||
| std::vector<AnfNodePtr> paramter_obj_nodes_; | std::vector<AnfNodePtr> paramter_obj_nodes_; | ||||
| @@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP | |||||
| return fg->NewCNode(inputs); | return fg->NewCNode(inputs); | ||||
| } | } | ||||
| size_t NewFgSeenGeneration(); | |||||
| // Find the root cnodes of a segment of cnodes. | // Find the root cnodes of a segment of cnodes. | ||||
| std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); | std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); | ||||
| // Find the leaf cnodes of a segment of cnodes. | // Find the leaf cnodes of a segment of cnodes. | ||||
| @@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { | |||||
| if (!clone_all_valuenodes_) { | if (!clone_all_valuenodes_) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto &value_nodes = manager_->valuenodes()[func_graph]; | |||||
| auto &value_nodes = func_graph->value_nodes(); | |||||
| for (auto &value_node : value_nodes) { | for (auto &value_node : value_nodes) { | ||||
| auto old_node = value_node.first; | auto old_node = value_node.first; | ||||
| MS_EXCEPTION_IF_NULL(old_node); | MS_EXCEPTION_IF_NULL(old_node); | ||||
| @@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { | |||||
| if (!clone_all_used_graphs_) { | if (!clone_all_used_graphs_) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto &used_graphs = manager_->func_graphs_used()[func_graph]; | |||||
| for (auto &used_graph : used_graphs) { | |||||
| todo_.push_back({used_graph.first, nullptr, {}}); | |||||
| auto &used = func_graph->func_graphs_used(); | |||||
| for (auto &fg : used) { | |||||
| todo_.push_back({fg.first, nullptr, {}}); | |||||
| } | } | ||||
| } | } | ||||
| @@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func | |||||
| } | } | ||||
| target_func_graph->set_return(return_node); | target_func_graph->set_return(return_node); | ||||
| auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; | |||||
| auto &cnodes = func_graph->func_graph_cnodes_index(); | |||||
| for (auto &cnode : cnodes) { | for (auto &cnode : cnodes) { | ||||
| auto parent = cnode.first->first->cast<CNodePtr>(); | auto parent = cnode.first->first->cast<CNodePtr>(); | ||||
| auto valuenode = parent->input(cnode.first->second); | auto valuenode = parent->input(cnode.first->second); | ||||
| @@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| const AnfNodeSet &nodes = manager_->nodes()[func_graph]; | |||||
| const AnfNodeSet &nodes = func_graph->nodes(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| CloneNode(node, target_func_graph); | CloneNode(node, target_func_graph); | ||||
| } | } | ||||
| @@ -78,19 +78,6 @@ void FuncGraphManager::Reset() { | |||||
| node_users_ = NodeUsersMap(); | node_users_ = NodeUsersMap(); | ||||
| signals_ = std::make_shared<Signals>(); | signals_ = std::make_shared<Signals>(); | ||||
| // FuncGraph --> AnfNode | |||||
| nodes_ = std::make_shared<NodesCollector>(this); | |||||
| // FuncGraph --> {AnfNode, Count} | |||||
| valuenodes_ = std::make_shared<ValueNodesCollector>(this); | |||||
| free_variables_direct_ = std::make_shared<FVDirectCollector>(this); | |||||
| func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this); | |||||
| // FuncGraph --> {FuncGraph, Count} | |||||
| func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); | |||||
| func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); | |||||
| func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); | |||||
| func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); | |||||
| func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this); | func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this); | ||||
| func_graph_parent_ = std::make_shared<ParentComputer>(this); | func_graph_parent_ = std::make_shared<ParentComputer>(this); | ||||
| @@ -209,8 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { | |||||
| return; | return; | ||||
| } | } | ||||
| AddIntoManaged(func_graph); | AddIntoManaged(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(signals_); | |||||
| signals_->AddFuncGraph(func_graph); | |||||
| std::vector<AnfNodePtr> para = func_graph->parameters(); | std::vector<AnfNodePtr> para = func_graph->parameters(); | ||||
| AcquireNodes(para); | AcquireNodes(para); | ||||
| std::vector<AnfNodePtr> return_vec({func_graph->get_return()}); | std::vector<AnfNodePtr> return_vec({func_graph->get_return()}); | ||||
| @@ -224,7 +209,6 @@ void FuncGraphManager::Clear() { | |||||
| node_users_.clear(); | node_users_.clear(); | ||||
| roots_.clear(); | roots_.clear(); | ||||
| signals_->InvalidateCollector(); | |||||
| signals_->InvalidateComputer(); | signals_->InvalidateComputer(); | ||||
| } | } | ||||
| @@ -303,8 +287,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||||
| MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); | |||||
| auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; | |||||
| auto &users_cnode_index = func_graph->func_graph_cnodes_index(); | |||||
| if (!users_cnode_index.empty() && !ignore_users) { | if (!users_cnode_index.empty() && !ignore_users) { | ||||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | ||||
| continue; | continue; | ||||
| @@ -317,10 +300,8 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||||
| std::vector<AnfNodePtr> return_vec = {func_graph->get_return()}; | std::vector<AnfNodePtr> return_vec = {func_graph->get_return()}; | ||||
| todo.update(MaybeDropNodes(return_vec)); | todo.update(MaybeDropNodes(return_vec)); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(signals_); | |||||
| for (auto &fg : dropped) { | for (auto &fg : dropped) { | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| signals_->DropFuncGraph(fg); | |||||
| all_nodes_.difference_update(fg->parameters()); | all_nodes_.difference_update(fg->parameters()); | ||||
| (void)func_graphs_.erase(fg); | (void)func_graphs_.erase(fg); | ||||
| if (fg->manager().get() == this) { | if (fg->manager().get() == this) { | ||||
| @@ -339,7 +320,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||||
| return; | return; | ||||
| } | } | ||||
| (void)users_node.erase(make_pair(node, index)); | (void)users_node.erase(make_pair(node, index)); | ||||
| signals_->DropEdge(node, index, inp); | |||||
| DropEdge(node, index, inp); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | ||||
| if (inp->func_graph() != nullptr) { | if (inp->func_graph() != nullptr) { | ||||
| @@ -351,8 +332,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||||
| } | } | ||||
| auto &users_node = node_users_[inp]; | auto &users_node = node_users_[inp]; | ||||
| users_node.add(make_pair(node, index)); | users_node.add(make_pair(node, index)); | ||||
| MS_EXCEPTION_IF_NULL(signals_); | |||||
| signals_->AddEdge(node, index, inp); | |||||
| AddEdge(node, index, inp); | |||||
| } | } | ||||
| } | } | ||||
| @@ -392,8 +372,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) { | |||||
| FuncGraphPtr fg = node->func_graph(); | FuncGraphPtr fg = node->func_graph(); | ||||
| if (fg != nullptr) { | if (fg != nullptr) { | ||||
| AddFuncGraph(fg); | AddFuncGraph(fg); | ||||
| fg->AddNode(node); | |||||
| } | } | ||||
| signals_->AddNode(node); | |||||
| ProcessInputs(node, kIncEdge); | ProcessInputs(node, kIncEdge); | ||||
| } | } | ||||
| } | } | ||||
| @@ -401,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) { | |||||
| FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { | FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { | ||||
| AnfNodeSet nodes_ordered(nodes); | AnfNodeSet nodes_ordered(nodes); | ||||
| FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | ||||
| MS_EXCEPTION_IF_NULL(signals_); | |||||
| while (!nodes_ordered.empty()) { | while (!nodes_ordered.empty()) { | ||||
| AnfNodePtr node = nodes_ordered.pop(); | AnfNodePtr node = nodes_ordered.pop(); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -424,7 +402,10 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> & | |||||
| } | } | ||||
| ProcessInputs(node, kDecEdge); | ProcessInputs(node, kDecEdge); | ||||
| (void)all_nodes_.erase(node); | (void)all_nodes_.erase(node); | ||||
| signals_->DropNode(node); | |||||
| if (node->func_graph() != nullptr) { | |||||
| node->func_graph()->DropNode(node); | |||||
| } | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| nodes_ordered.update(cnode->inputs()); | nodes_ordered.update(cnode->inputs()); | ||||
| @@ -462,35 +443,21 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||||
| int index = 0; | int index = 0; | ||||
| (void)node_users_[source_prim].erase(make_pair(source_return, index)); | (void)node_users_[source_prim].erase(make_pair(source_return, index)); | ||||
| signals_->DropEdge(source_return, index, source_prim); | |||||
| DropEdge(source_return, index, source_prim); | |||||
| index = 1; | index = 1; | ||||
| (void)node_users_[source_output].erase(make_pair(source_return, index)); | (void)node_users_[source_output].erase(make_pair(source_return, index)); | ||||
| signals_->DropEdge(source_return, index, source_output); | |||||
| DropEdge(source_return, index, source_output); | |||||
| (void)all_nodes_.erase(source_return); | (void)all_nodes_.erase(source_return); | ||||
| (void)node_users_.erase(source_return); | (void)node_users_.erase(source_return); | ||||
| signals_->DropNode(source_return); | |||||
| source->DropNode(source_return); | |||||
| for (auto &node : source->nodes()) { | for (auto &node : source->nodes()) { | ||||
| node->set_func_graph(target); | node->set_func_graph(target); | ||||
| if (node->scope() == kDefaultScope) { | if (node->scope() == kDefaultScope) { | ||||
| node->set_scope(scope); | node->set_scope(scope); | ||||
| } | } | ||||
| } | } | ||||
| for (auto &child : this->func_graph_child_direct()[source]) { | |||||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | |||||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | |||||
| } | |||||
| for (auto &fv_count : this->free_variables_direct()[source]) { | |||||
| auto fv_g = fv_count.first->func_graph(); | |||||
| auto &count_on_g = this->func_graph_child_direct()[fv_g]; | |||||
| auto pair = count_on_g.find(source); | |||||
| if (fv_g != target && pair != count_on_g.end()) { | |||||
| (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); | |||||
| } | |||||
| (void)count_on_g.erase(source); | |||||
| } | |||||
| signals_->MoveAllCNode(source, target); | |||||
| signals_->InvalidateComputer(); | |||||
| signals_->DropFuncGraph(source); | |||||
| MoveAllNodes(source, target); | |||||
| all_nodes_.difference_update(source->parameters()); | all_nodes_.difference_update(source->parameters()); | ||||
| (void)func_graphs_.erase(source); | (void)func_graphs_.erase(source); | ||||
| if (source->manager().get() == this) { | if (source->manager().get() == this) { | ||||
| @@ -498,6 +465,64 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||||
| } | } | ||||
| } | } | ||||
| inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| auto fg = node->func_graph(); | |||||
| if (input->isa<ValueNode>()) { | |||||
| fg->AddValueNode(input); | |||||
| if (IsValueNode<FuncGraph>(input)) { | |||||
| auto used = GetValueNode<FuncGraphPtr>(input); | |||||
| used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); | |||||
| if (fg->AddFuncGraphUsed(used)) { | |||||
| signals_->InvalidateComputer(); | |||||
| } | |||||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||||
| fg->AddJFuncGraph(used); | |||||
| } | |||||
| } | |||||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||||
| if (fg->AddFreeVariable(input)) { | |||||
| signals_->InvalidateComputer(); | |||||
| } | |||||
| } | |||||
| } | |||||
| inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| auto fg = node->func_graph(); | |||||
| if (input->isa<ValueNode>()) { | |||||
| fg->DropValueNode(input); | |||||
| if (IsValueNode<FuncGraph>(input)) { | |||||
| auto used = GetValueNode<FuncGraphPtr>(input); | |||||
| used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); | |||||
| if (fg->DropFuncGraphUsed(used)) { | |||||
| signals_->InvalidateComputer(); | |||||
| } | |||||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||||
| fg->DropJFuncGraph(used); | |||||
| } | |||||
| } | |||||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||||
| if (fg->DropFreeVariable(input)) { | |||||
| signals_->InvalidateComputer(); | |||||
| } | |||||
| } | |||||
| } | |||||
| inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||||
| target->CopyNodes(source); | |||||
| target->CopyValueNodes(source); | |||||
| target->CopyFuncGraphCNodesIndex(source); | |||||
| target->CopyFreeVariables(source); | |||||
| target->CopyFuncGraphsUsed(source); | |||||
| target->CopyJFuncGraphs(source); | |||||
| signals_->InvalidateComputer(); | |||||
| source->ClearNodes(); | |||||
| source->ClearValueNodes(); | |||||
| source->ClearFuncGraphCNodesIndex(); | |||||
| source->ClearFreeVariables(); | |||||
| source->ClearFuncGraphsUsed(); | |||||
| source->ClearJFuncGraphs(); | |||||
| } | |||||
| FuncGraphTransaction FuncGraphManager::Transact() { | FuncGraphTransaction FuncGraphManager::Transact() { | ||||
| auto tr = FuncGraphTransaction(this); | auto tr = FuncGraphTransaction(this); | ||||
| return tr; | return tr; | ||||
| @@ -610,54 +635,14 @@ void FuncGraphTransaction::Commit() { | |||||
| } | } | ||||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) | FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) | ||||
| : manager_(manager), include_func_graph_none_(false) { | |||||
| manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); | |||||
| manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); | |||||
| manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge); | |||||
| manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge); | |||||
| manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); | |||||
| } | |||||
| NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { | |||||
| include_func_graph_none_ = true; | |||||
| nodes_analysis_[nullptr] = AnfNodeSet(); | |||||
| manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode); | |||||
| manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode); | |||||
| } | |||||
| void NodesCollector::OnAddNode(AnfNodePtr n) { | |||||
| if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) { | |||||
| nodes_analysis_[n->func_graph()] = AnfNodeSet(); | |||||
| } | |||||
| nodes_analysis_[n->func_graph()].add(n); | |||||
| } | |||||
| void NodesCollector::OnDropNode(AnfNodePtr n) { | |||||
| (void)nodes_analysis_[n->func_graph()].erase(n); | |||||
| auto graph = n->func_graph(); | |||||
| // Remove the node from order list. | |||||
| if (graph) { | |||||
| graph->EraseUnusedNodeInOrder(n); | |||||
| } | |||||
| } | |||||
| void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| // change the owner of node except for the src's return node | |||||
| for (auto &it : nodes_analysis_[src]) { | |||||
| nodes_analysis_[dst].add(it); | |||||
| } | |||||
| (void)nodes_analysis_.erase(src); | |||||
| } | |||||
| void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | |||||
| : manager_(manager), include_func_graph_none_(false) {} | |||||
| DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); | |||||
| } | } | ||||
| void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | |||||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | ||||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | template <typename ValueT, class CollectorHash, class CollectorEqual> | ||||
| @@ -706,65 +691,6 @@ bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const F | |||||
| } | } | ||||
| } | } | ||||
| void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (inp->isa<ValueNode>()) { | |||||
| (void)Mod(node->func_graph(), inp, direction); | |||||
| } | |||||
| } | |||||
| void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| (void)count_nodes_map_.erase(src); | |||||
| } | |||||
| void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, | |||||
| EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsValueNode<FuncGraph>(inp)) { | |||||
| (void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)), | |||||
| direction); | |||||
| } | |||||
| } | |||||
| void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| // Ignore the user graph who may own itself. | |||||
| if (dst != it.first->first->func_graph()) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| } | |||||
| (void)count_nodes_map_.erase(src); | |||||
| } | |||||
| void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(inp); | |||||
| FuncGraphPtr fg1 = node->func_graph(); | |||||
| FuncGraphPtr fg2 = inp->func_graph(); | |||||
| if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { | |||||
| (void)Mod(fg1, inp, direction); | |||||
| } | |||||
| } | |||||
| void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| FuncGraphPtr fg2 = it.first->func_graph(); | |||||
| if (fg2 != dst) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| } | |||||
| (void)count_nodes_map_.erase(src); | |||||
| } | |||||
| static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { | |||||
| FuncGraphPtr gn = std::make_shared<FuncGraph>(); | |||||
| (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); | |||||
| return gn; | |||||
| } | |||||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | ||||
| auto &d = count_func_graphs_map_[func_graph]; | auto &d = count_func_graphs_map_[func_graph]; | ||||
| if (d.count(key) == 0) { | if (d.count(key) == 0) { | ||||
| @@ -804,87 +730,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(inp); | |||||
| FuncGraphPtr fg1 = node->func_graph(); | |||||
| FuncGraphPtr fg2 = inp->func_graph(); | |||||
| if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { | |||||
| (void)Mod(fg2, fg1, direction); | |||||
| } | |||||
| } | |||||
| void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| FuncGraphPtr fg = it.first; | |||||
| if (fg != dst) { | |||||
| (void)Inc(dst, fg, it.second); | |||||
| } | |||||
| } | |||||
| (void)count_func_graphs_map_.erase(src); | |||||
| } | |||||
| void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| FuncGraphPtr fg1 = node->func_graph(); | |||||
| // possible child parent | |||||
| if (IsValueNode<FuncGraph>(inp)) { | |||||
| FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp); | |||||
| if (Mod(fg1, ParentProxy(fg2), direction)) { | |||||
| manager_->signals()->InvalidateComputer(); | |||||
| } | |||||
| } | |||||
| // from fv | |||||
| FuncGraphPtr fg2 = inp->func_graph(); | |||||
| if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { | |||||
| // node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent | |||||
| if (Mod(fg1, fg2, direction)) { | |||||
| manager_->signals()->InvalidateComputer(); | |||||
| } | |||||
| } | |||||
| } | |||||
| void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| if (it.first != dst) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| } | |||||
| (void)count_func_graphs_map_.erase(src); | |||||
| } | |||||
| void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsValueNode<FuncGraph>(inp)) { | |||||
| (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); | |||||
| } | |||||
| } | |||||
| void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| // all graph use in src need to change to dst, so meger the to dst use | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| (void)count_func_graphs_map_[dst].erase(src); | |||||
| (void)count_func_graphs_map_.erase(src); | |||||
| } | |||||
| void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { | |||||
| if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { | |||||
| (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); | |||||
| MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph " | |||||
| << GetValueNode<FuncGraphPtr>(inp)->ToString() << " which contains J(func_graph), dir: " << direction; | |||||
| } | |||||
| } | |||||
| void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| // all graph use in src need to change to dst, so meger the to dst use | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | |||||
| } | |||||
| (void)count_func_graphs_map_.erase(src); | |||||
| } | |||||
| DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | ||||
| @@ -905,22 +750,24 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) { | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||||
| if (path == nullptr || path->contains(fg)) { | |||||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { | |||||
| if (fg->seen_ == seen_num) { | |||||
| return std::make_shared<FuncGraphSet>(); | return std::make_shared<FuncGraphSet>(); | ||||
| } | } | ||||
| FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | ||||
| FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; | |||||
| for (auto &dep : deps[fg]) { | |||||
| MS_EXCEPTION_IF_NULL(dep.first); | |||||
| auto proxy = dep.first->transforms().find("proxy"); | |||||
| if (proxy != dep.first->transforms().end()) { | |||||
| path->add(fg); | |||||
| auto gt = proxy->second.func_graph(); | |||||
| parents->update(SeekParents(gt, path)); | |||||
| } else { | |||||
| parents->add(dep.first); | |||||
| } | |||||
| // Append all the fvs in fg. | |||||
| auto &fvs = fg->free_variables(); | |||||
| for (auto fv : fvs) { | |||||
| parents->add(fv.first->func_graph()); | |||||
| } | |||||
| // Search the fv in fg's child func graph. | |||||
| auto &fgs = fg->func_graphs_used(); | |||||
| for (auto &item : fgs) { | |||||
| fg->seen_ = seen_num; | |||||
| auto gt = item.first; | |||||
| parents->update(SeekParents(gt, seen_num)); | |||||
| } | } | ||||
| (void)parents->erase(fg); | (void)parents->erase(fg); | ||||
| return parents; | return parents; | ||||
| @@ -928,10 +775,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f | |||||
| void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { | void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| all_parents_direct_ = &(manager_->func_graph_parents_direct()); | |||||
| MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size(); | |||||
| func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); | |||||
| MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); | |||||
| func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); | |||||
| } | } | ||||
| bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { | bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { | ||||
| @@ -1001,21 +845,23 @@ void FVTotalComputer::RealRecompute() { | |||||
| } | } | ||||
| for (auto &fg : manager->func_graphs()) { | for (auto &fg : manager->func_graphs()) { | ||||
| AnfNodeCounterMap items = manager->free_variables_direct()[fg]; | |||||
| AnfNodeCounterMap items = fg->free_variables(); | |||||
| for (auto &iter : items) { | for (auto &iter : items) { | ||||
| auto curr = fg; | auto curr = fg; | ||||
| while (curr) { | |||||
| while (curr != nullptr) { | |||||
| (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | ||||
| curr = manager->parent(curr); | curr = manager->parent(curr); | ||||
| const AnfNodeSet &nodes = manager->nodes()[curr]; | |||||
| if (nodes.contains(iter.first)) { | |||||
| break; | |||||
| if (curr != nullptr) { | |||||
| const AnfNodeSet &all_nodes = curr->nodes(); | |||||
| if (all_nodes.contains(iter.first)) { | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| auto items_fg = manager->func_graphs_used()[fg]; | |||||
| for (auto &iter : items_fg) { | |||||
| auto &used = fg->func_graphs_used(); | |||||
| for (auto &iter : used) { | |||||
| auto p = manager->parent(iter.first); | auto p = manager->parent(iter.first); | ||||
| if (p == nullptr) { | if (p == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -1041,7 +887,6 @@ void FVTotalComputer::RealRecompute() { | |||||
| void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| auto &used = this->manager_->func_graphs_used(); | |||||
| std::vector<FuncGraphPtr> todo; | std::vector<FuncGraphPtr> todo; | ||||
| std::vector<FuncGraphPtr> todo_new; | std::vector<FuncGraphPtr> todo_new; | ||||
| @@ -1049,7 +894,7 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| todo_new.clear(); | todo_new.clear(); | ||||
| for (auto > : todo) { | for (auto > : todo) { | ||||
| for (auto &item : used[gt]) { | |||||
| for (auto &item : gt->func_graphs_used()) { | |||||
| auto used_fg = item.first; | auto used_fg = item.first; | ||||
| if (used_fg == fg) { | if (used_fg == fg) { | ||||
| func_graph_used_total_analysis_[fg].add(used_fg); | func_graph_used_total_analysis_[fg].add(used_fg); | ||||
| @@ -1068,7 +913,6 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { | bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto &used = manager->func_graphs_used(); | |||||
| std::vector<FuncGraphPtr> todo; | std::vector<FuncGraphPtr> todo; | ||||
| std::vector<FuncGraphPtr> todo_new; | std::vector<FuncGraphPtr> todo_new; | ||||
| todo.push_back(fg); | todo.push_back(fg); | ||||
| @@ -1076,7 +920,7 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f | |||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| todo_new.clear(); | todo_new.clear(); | ||||
| for (auto > : todo) { | for (auto > : todo) { | ||||
| for (auto &item : used[gt]) { | |||||
| for (auto &item : gt->func_graphs_used()) { | |||||
| auto used_g = item.first; | auto used_g = item.first; | ||||
| if (used_g == fg) { | if (used_g == fg) { | ||||
| return true; | return true; | ||||
| @@ -1108,8 +952,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F | |||||
| } | } | ||||
| } else { | } else { | ||||
| trace->push_back(fg); | trace->push_back(fg); | ||||
| auto &used_fgs = manager_->func_graphs_used()[fg]; | |||||
| for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { | |||||
| auto &items = fg->func_graphs_used(); | |||||
| for (auto iter = items.begin(); iter != items.end(); (void)iter++) { | |||||
| CheckRecursiveGraphs(iter->first, trace); | CheckRecursiveGraphs(iter->first, trace); | ||||
| } | } | ||||
| trace->pop_back(); | trace->pop_back(); | ||||
| @@ -1119,31 +963,28 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F | |||||
| } | } | ||||
| } | } | ||||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||||
| MS_EXCEPTION_IF_NULL(path); | |||||
| if (path->contains(fg)) { | |||||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { | |||||
| if (fg->seen_ == seen_num) { | |||||
| MS_LOG(DEBUG) << fg->ToString() << " had been checked"; | MS_LOG(DEBUG) << fg->ToString() << " had been checked"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(manager_); | |||||
| auto &func_graph_counter_map = manager_->func_graph_j_direct(); | |||||
| if (!func_graph_counter_map[fg].empty()) { | |||||
| auto &j_fgs = fg->j_func_graphs(); | |||||
| if (!j_fgs.empty()) { | |||||
| // check g1->J(fg)->g2->g cycle; | // check g1->J(fg)->g2->g cycle; | ||||
| auto contains_j = | |||||
| std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(), | |||||
| [path](const std::pair<FuncGraphPtr, int> iter) { return !path->contains(iter.first); }); | |||||
| if (contains_j != func_graph_counter_map[fg].end()) { | |||||
| auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) { | |||||
| return iter.first->seen_ != seen_num; | |||||
| }); | |||||
| if (contains_j != j_fgs.end()) { | |||||
| MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; | MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; | ||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| path->add(fg); | |||||
| fg->seen_ = seen_num; | |||||
| // check if func graphs used contains J(func_graph); | // check if func graphs used contains J(func_graph); | ||||
| auto &used = this->manager_->func_graphs_used(); | |||||
| for (auto &item : used[fg]) { | |||||
| for (auto &item : fg->func_graphs_used()) { | |||||
| auto used_g = item.first; | auto used_g = item.first; | ||||
| if (SeekJ(used_g, path)) { | |||||
| if (SeekJ(used_g, seen_num)) { | |||||
| MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; | MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1153,7 +994,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt | |||||
| } | } | ||||
| void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { | void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>(); | |||||
| this->j_total_analysis_[fg] = SeekJ(fg, path); | |||||
| this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -140,44 +140,6 @@ class FuncGraphAnalysis { | |||||
| using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; | using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; | ||||
| // graphs analysis which compute in write, read needn't recompute | |||||
| class DepCollector : public FuncGraphAnalysis { | |||||
| public: | |||||
| explicit DepCollector(const FuncGraphManager *manager); | |||||
| ~DepCollector() override = default; | |||||
| void Reset() { ExtraReset(); } | |||||
| void OnInvalidateCollector() { Reset(); } | |||||
| protected: | |||||
| // inherit from FuncGraphAnalysis | |||||
| void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||||
| void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||||
| // subclass can override; | |||||
| virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} | |||||
| }; | |||||
| class NodesCollector final : public DepCollector { | |||||
| public: | |||||
| explicit NodesCollector(const FuncGraphManager *m); | |||||
| ~NodesCollector() override = default; | |||||
| const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } | |||||
| size_t size() const override { return nodes_analysis_.size(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); } | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| FuncGraphToAnfNodeMap nodes_analysis_; | |||||
| protected: | |||||
| void ExtraReset() override { nodes_analysis_.clear(); } | |||||
| void OnAddNode(AnfNodePtr n) override; | |||||
| void OnDropNode(AnfNodePtr n) override; | |||||
| }; | |||||
| struct CNodeIndexHasher { | struct CNodeIndexHasher { | ||||
| std::size_t operator()(const CNodeIndexPairPtr pair) const { | std::size_t operator()(const CNodeIndexPairPtr pair) const { | ||||
| MS_EXCEPTION_IF_NULL(pair); | MS_EXCEPTION_IF_NULL(pair); | ||||
| @@ -204,59 +166,21 @@ struct CNodeIndexEqual { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||||
| class CounterAnfNodeCollector : public DepCollector { | |||||
| public: | |||||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterAnfNodeCollector() override = default; | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; } | |||||
| size_t size() const override { return count_nodes_map_.size(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { | |||||
| count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>(); | |||||
| } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | |||||
| bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_; | |||||
| protected: | |||||
| void ExtraReset() override { count_nodes_map_.clear(); } | |||||
| }; | |||||
| class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||||
| public: | |||||
| explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~ValueNodesCollector() override = default; | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| // Record the CNode and its input index, who points to the function graph. | |||||
| class FuncGraphUsersCNodeIndexCollector final | |||||
| : public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> { | |||||
| // graphs analysis which compute in write, read needn't recompute | |||||
| class DepCollector : public FuncGraphAnalysis { | |||||
| public: | public: | ||||
| explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FuncGraphUsersCNodeIndexCollector() override = default; | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| explicit DepCollector(const FuncGraphManager *manager); | |||||
| ~DepCollector() override = default; | |||||
| class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> { | |||||
| public: | |||||
| explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FVDirectCollector() override = default; | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| void Reset() { ExtraReset(); } | |||||
| void OnInvalidateCollector() { Reset(); } | |||||
| protected: | protected: | ||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| // inherit from FuncGraphAnalysis | |||||
| void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||||
| void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||||
| // subclass can override; | |||||
| virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} | |||||
| }; | }; | ||||
| class CounterFuncGraphCollector : public DepCollector { | class CounterFuncGraphCollector : public DepCollector { | ||||
| @@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector { | |||||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | void ExtraReset() override { count_func_graphs_map_.clear(); } | ||||
| }; | }; | ||||
| class FuncGraphChildDirect final : public CounterFuncGraphCollector { | |||||
| public: | |||||
| explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| ~FuncGraphChildDirect() override = default; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| // graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy | |||||
| // parents: | |||||
| // 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent | |||||
| // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f | |||||
| class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | |||||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||||
| class CounterAnfNodeCollector : public DepCollector { | |||||
| public: | public: | ||||
| explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| ~FuncGraphParentsDirectCollector() override = default; | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterAnfNodeCollector() override = default; | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; } | |||||
| // graph's all used graphs: key is g, value is g used graph | |||||
| class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||||
| public: | |||||
| explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||||
| ~FuncGraphsUsedCollector() override = default; | |||||
| size_t size() const override { return count_nodes_map_.size(); } | |||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { | |||||
| count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>(); | |||||
| } | |||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | |||||
| protected: | |||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| }; | |||||
| bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||||
| class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | |||||
| public: | |||||
| explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; | |||||
| ~FuncGraphJDirectCollector() override = default; | |||||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_; | |||||
| protected: | protected: | ||||
| void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; | |||||
| void ExtraReset() override { count_nodes_map_.clear(); } | |||||
| }; | }; | ||||
| using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; | using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; | ||||
| @@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis { | |||||
| // graph g's all direct or proxy parents | // graph g's all direct or proxy parents | ||||
| class FuncGraphParentsTotalComputer final : public DepComputer { | class FuncGraphParentsTotalComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||||
| ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } | |||||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~FuncGraphParentsTotalComputer() override = default; | |||||
| FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | ||||
| @@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer { | |||||
| void RealRecompute(FuncGraphPtr fg) override; | void RealRecompute(FuncGraphPtr fg) override; | ||||
| private: | private: | ||||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); | |||||
| // when SeekParents calls itself recursively, it can access these variables by class member | |||||
| // other than pass by formal parameters, it can save 1 parameter for SeekParents(). | |||||
| FuncGraphToFuncGraphCounterMap *all_parents_direct_; | |||||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num); | |||||
| }; | }; | ||||
| using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | ||||
| @@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer { | |||||
| void ExtraReset() override { j_total_analysis_.clear(); } | void ExtraReset() override { j_total_analysis_.clear(); } | ||||
| void RealRecompute(FuncGraphPtr fg) override; | void RealRecompute(FuncGraphPtr fg) override; | ||||
| bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); | |||||
| bool SeekJ(const FuncGraphPtr &fg, size_t seen_num); | |||||
| }; | }; | ||||
| class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | ||||
| @@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| NodeUsersMap &node_users() { return node_users_; } | NodeUsersMap &node_users() { return node_users_; } | ||||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | |||||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const { | |||||
| return free_variables_direct_->count_nodes_map_; | |||||
| } | |||||
| FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const { | |||||
| return func_graph_cnodes_index_->count_nodes_map_; | |||||
| } | |||||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | |||||
| return func_graph_child_direct_->count_func_graphs_map_; | |||||
| } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { | |||||
| return func_graph_parents_direct_->count_func_graphs_map_; | |||||
| } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||||
| FVTotalMap &free_variables_total() const; | FVTotalMap &free_variables_total() const; | ||||
| FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; | FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; | ||||
| @@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| // Static Analysis | // Static Analysis | ||||
| NodeUsersMap node_users_; | NodeUsersMap node_users_; | ||||
| AnfNodeSet all_nodes_; // managed nodes | AnfNodeSet all_nodes_; // managed nodes | ||||
| std::shared_ptr<NodesCollector> nodes_; | |||||
| std::shared_ptr<ValueNodesCollector> valuenodes_; | |||||
| std::shared_ptr<FVDirectCollector> free_variables_direct_; | |||||
| std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_; | |||||
| std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; | |||||
| std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; | |||||
| std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; | |||||
| std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; | |||||
| // Dynamic Analysis | // Dynamic Analysis | ||||
| std::shared_ptr<ParentComputer> func_graph_parent_; | std::shared_ptr<ParentComputer> func_graph_parent_; | ||||
| @@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); | FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); | ||||
| void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, | void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, | ||||
| Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); | Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); | ||||
| void AddEdge(AnfNodePtr node, int index, AnfNodePtr input); | |||||
| void DropEdge(AnfNodePtr node, int index, AnfNodePtr input); | |||||
| void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target); | |||||
| FuncGraphSet roots_; // managed roots | FuncGraphSet roots_; // managed roots | ||||
| FuncGraphSet func_graphs_; // managed func graphs | FuncGraphSet func_graphs_; // managed func graphs | ||||
| @@ -492,7 +492,7 @@ void DFunctor::MapParamObject() { | |||||
| void DFunctor::MapValueObject() { | void DFunctor::MapValueObject() { | ||||
| // Map ValueNode. | // Map ValueNode. | ||||
| auto manager = resources_->manager(); | auto manager = resources_->manager(); | ||||
| auto &value_nodes = manager->valuenodes()[primal_graph_]; | |||||
| auto &value_nodes = primal_graph_->value_nodes(); | |||||
| for (const auto &value_pair : value_nodes) { | for (const auto &value_pair : value_nodes) { | ||||
| auto node = value_pair.first; | auto node = value_pair.first; | ||||
| auto parent_adjoint = FindAdjoint(node); | auto parent_adjoint = FindAdjoint(node); | ||||
| @@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes( | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; | std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; | ||||
| // record the node input to be replaced | // record the node input to be replaced | ||||
| NodeInputReplMap repl_node_inputs; | NodeInputReplMap repl_node_inputs; | ||||
| const AnfNodeSet &nodes = manager->nodes()[graph]; | |||||
| const AnfNodeSet &nodes = graph->nodes(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode( | |||||
| ResetSharedOp(); | ResetSharedOp(); | ||||
| std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = | std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = | ||||
| std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced | std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced | ||||
| const AnfNodeSet &nodes = manager->nodes()[graph]; | |||||
| const AnfNodeSet &nodes = graph->nodes(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| auto manager = res->manager(); | auto manager = res->manager(); | ||||
| // Remove duplicated value nodes, due to replace operation, can't use reference. | // Remove duplicated value nodes, due to replace operation, can't use reference. | ||||
| auto value_nodes = manager->valuenodes()[func_graph]; | |||||
| auto value_nodes = func_graph->value_nodes(); | |||||
| HashCache hash_cache; | HashCache hash_cache; | ||||
| HashValue hashes; | HashValue hashes; | ||||
| for (const auto &value_pair : value_nodes) { | for (const auto &value_pair : value_nodes) { | ||||
| @@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { | |||||
| void TraverseGraphMap( | void TraverseGraphMap( | ||||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | ||||
| const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, | |||||
| const FuncGraphSet &fgs, | |||||
| const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | ||||
| MS_EXCEPTION_IF_NULL(manager_ptr); | MS_EXCEPTION_IF_NULL(manager_ptr); | ||||
| MS_EXCEPTION_IF_NULL(tr); | MS_EXCEPTION_IF_NULL(tr); | ||||
| for (const auto &ct_graphs : cts) { | |||||
| for (const auto &ct_any : ct_graphs.second) { | |||||
| for (const auto &fg : fgs) { | |||||
| for (const auto &ct_any : fg->value_nodes()) { | |||||
| AnfNodePtr const_primitive_node = ct_any.first; | AnfNodePtr const_primitive_node = ct_any.first; | ||||
| if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { | if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { | ||||
| auto users = manager_ptr->node_users()[const_primitive_node]; | auto users = manager_ptr->node_users()[const_primitive_node]; | ||||
| @@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { | |||||
| }; | }; | ||||
| FuncGraphTransaction tr = manager_ptr->Transact(); | FuncGraphTransaction tr = manager_ptr->Transact(); | ||||
| auto &cts = manager_ptr->valuenodes(); | |||||
| TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); | |||||
| auto &fgs = manager_ptr->func_graphs(); | |||||
| TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); | |||||
| return graph; | return graph; | ||||
| } | } | ||||
| @@ -132,18 +132,6 @@ class NestingSpecs { | |||||
| CheckAnfNodeCounter(counter_p); | CheckAnfNodeCounter(counter_p); | ||||
| return; | return; | ||||
| } | } | ||||
| auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results); | |||||
| if (counter_pair != nullptr) { | |||||
| CheckCNodeIndexPairCounter(counter_pair); | |||||
| return; | |||||
| } | |||||
| auto nodes = dynamic_pointer_cast<NodesCollector>(results); | |||||
| if (nodes != nullptr) { | |||||
| CheckNodes(nodes); | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| private: | private: | ||||
| @@ -205,33 +193,7 @@ class NestingSpecs { | |||||
| ASSERT_EQ(clean_results, expected_); | ASSERT_EQ(clean_results, expected_); | ||||
| } | } | ||||
| void CheckNodes(std::shared_ptr<NodesCollector> results) { | |||||
| std::map<std::string, std::set<std::string>> clean_results; | |||||
| for (auto& iter : results->nodes_analysis()) { | |||||
| auto key = iter.first; | |||||
| auto value = iter.second; | |||||
| if (key == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string k = Name(key); | |||||
| std::set<std::string> v; | |||||
| for (auto& node : value) { | |||||
| if (!node->isa<CNode>() && !Name(node).empty()) { | |||||
| v.insert(Name(node)); | |||||
| } | |||||
| } | |||||
| if (!v.empty()) { | |||||
| clean_results[k] = v; | |||||
| } | |||||
| } | |||||
| ASSERT_EQ(clean_results, expected_); | |||||
| } | |||||
| // Add CheckNesting function | // Add CheckNesting function | ||||
| void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { | void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { | ||||
| std::map<std::string, std::set<std::string>> clean_results; | std::map<std::string, std::set<std::string>> clean_results; | ||||
| for (auto& iter : results->count_nodes_map()) { | for (auto& iter : results->count_nodes_map()) { | ||||
| @@ -258,32 +220,6 @@ class NestingSpecs { | |||||
| ASSERT_EQ(clean_results, expected_); | ASSERT_EQ(clean_results, expected_); | ||||
| } | } | ||||
| void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) { | |||||
| std::map<std::string, std::set<std::string>> clean_results; | |||||
| for (auto& iter : results->count_nodes_map()) { | |||||
| auto key = iter.first; | |||||
| auto value = iter.second; | |||||
| if (key == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string k = Name(key); | |||||
| std::set<std::string> v; | |||||
| for (auto& node : value) { | |||||
| auto fg = node.first->first; | |||||
| if (!Name(fg).empty()) { | |||||
| v.insert(Name(fg)); | |||||
| } | |||||
| } | |||||
| if (!v.empty()) { | |||||
| clean_results[k] = v; | |||||
| } | |||||
| } | |||||
| ASSERT_EQ(clean_results, expected_); | |||||
| } | |||||
| void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { | void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { | ||||
| std::map<std::string, std::set<std::string>> clean_results; | std::map<std::string, std::set<std::string>> clean_results; | ||||
| for (auto& iter : results->count_func_graphs_map()) { | for (auto& iter : results->count_func_graphs_map()) { | ||||
| @@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() { | |||||
| } | } | ||||
| // Add TestManager::CheckManager function to checkout the result | // Add TestManager::CheckManager function to checkout the result | ||||
| void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { | void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { | ||||
| auto size = mng->func_graphs().size(); | auto size = mng->func_graphs().size(); | ||||
| ASSERT_EQ(size + 1, mng->nodes().size()); | |||||
| ASSERT_EQ(size, mng->free_variables_total().size()); | ASSERT_EQ(size, mng->free_variables_total().size()); | ||||
| ASSERT_EQ(size, mng->valuenodes().size()); | |||||
| ASSERT_EQ(size, mng->free_variables_direct().size()); | |||||
| ASSERT_EQ(size, mng->func_graph_cnodes_index().size()); | |||||
| ASSERT_EQ(size, mng->func_graph_parents_direct().size()); | |||||
| ASSERT_EQ(size, mng->func_graphs_used().size()); | |||||
| } | } | ||||
| TEST_F(TestManager, test_scalar_add_manual) { | TEST_F(TestManager, test_scalar_add_manual) { | ||||
| @@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) { | |||||
| ASSERT_EQ(1, mng->roots().size()); | ASSERT_EQ(1, mng->roots().size()); | ||||
| CheckAnalysisSize(mng); | CheckAnalysisSize(mng); | ||||
| auto nodes = mng->nodes(); | |||||
| ASSERT_EQ(3, nodes[nullptr].size()); | |||||
| ASSERT_EQ(2, nodes[f].size()); | |||||
| ASSERT_EQ(1, nodes[g].size()); | |||||
| ASSERT_EQ(2, f->nodes().size()); | |||||
| ASSERT_EQ(1, g->nodes().size()); | |||||
| auto users = mng->node_users(); | auto users = mng->node_users(); | ||||
| for (auto& iter : users) { | for (auto& iter : users) { | ||||
| ASSERT_EQ(1, iter.second.size()); | ASSERT_EQ(1, iter.second.size()); | ||||
| } | } | ||||
| auto graphs_used = mng->func_graphs_used(); | |||||
| ASSERT_EQ(1, graphs_used[f].size()); | |||||
| ASSERT_EQ(0, graphs_used[g].size()); | |||||
| ASSERT_EQ(1, f->func_graphs_used().size()); | |||||
| ASSERT_EQ(0, g->func_graphs_used().size()); | |||||
| auto fv_direct = mng->free_variables_direct(); | |||||
| ASSERT_EQ(0, fv_direct[f].size()); | |||||
| ASSERT_EQ(1, fv_direct[g].size()); | |||||
| ASSERT_EQ(0, f->free_variables().size()); | |||||
| ASSERT_EQ(1, g->free_variables().size()); | |||||
| auto fv_total = mng->free_variables_total(); | auto fv_total = mng->free_variables_total(); | ||||
| ASSERT_EQ(0, fv_total[f].size()); | ASSERT_EQ(0, fv_total[f].size()); | ||||
| ASSERT_EQ(1, fv_total[g].size()); | ASSERT_EQ(1, fv_total[g].size()); | ||||
| auto cnodes = mng->func_graph_cnodes_index(); | |||||
| ASSERT_EQ(0, cnodes[f].size()); | |||||
| ASSERT_EQ(1, cnodes[g].size()); | |||||
| ASSERT_EQ(0, f->func_graph_cnodes_index().size()); | |||||
| ASSERT_EQ(1, g->func_graph_cnodes_index().size()); | |||||
| } | } | ||||
| TEST_F(TestManager, test_deep_nested2_manual) { | TEST_F(TestManager, test_deep_nested2_manual) { | ||||
| @@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) { | |||||
| ASSERT_EQ(3, mng->func_graphs().size()); | ASSERT_EQ(3, mng->func_graphs().size()); | ||||
| ASSERT_EQ(1, mng->roots().size()); | ASSERT_EQ(1, mng->roots().size()); | ||||
| ASSERT_EQ(4, mng->nodes().size()); | |||||
| ASSERT_EQ(4, gfn->nodes().size()); | |||||
| ASSERT_EQ(20, mng->all_nodes().size()); | ASSERT_EQ(20, mng->all_nodes().size()); | ||||
| ASSERT_EQ(25, mng->node_users().size()); | ASSERT_EQ(25, mng->node_users().size()); | ||||
| CheckAnalysisSize(mng); | CheckAnalysisSize(mng); | ||||
| @@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) { | |||||
| ASSERT_EQ(3, mng->func_graphs().size()); | ASSERT_EQ(3, mng->func_graphs().size()); | ||||
| ASSERT_EQ(1, mng->roots().size()); | ASSERT_EQ(1, mng->roots().size()); | ||||
| ASSERT_EQ(4, mng->nodes().size()); | |||||
| ASSERT_EQ(20, mng->all_nodes().size()); | ASSERT_EQ(20, mng->all_nodes().size()); | ||||
| CheckAnalysisSize(mng); | CheckAnalysisSize(mng); | ||||
| } | } | ||||
| @@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) { | |||||
| FuncGraphPtr fg = getPyFun("ir_get_fn"); | FuncGraphPtr fg = getPyFun("ir_get_fn"); | ||||
| auto mng = Manage(fg); | auto mng = Manage(fg); | ||||
| const FuncGraphToAnfNodeMap& nodes = mng->nodes(); | |||||
| ASSERT_TRUE(nodes.find(fg) != nodes.end()); | |||||
| const auto &fgs = mng->func_graphs(); | |||||
| ASSERT_TRUE(fgs.contains(fg)); | |||||
| FuncGraphSet s; | FuncGraphSet s; | ||||
| s.add(fg); | s.add(fg); | ||||
| mng->MaybeDropFuncGraphs(s); | mng->MaybeDropFuncGraphs(s); | ||||
| ASSERT_TRUE(nodes.find(fg) != nodes.end()); | |||||
| ASSERT_TRUE(fgs.contains(fg)); | |||||
| } | } | ||||
| TEST_F(TestManager, test_keep_roots) { | TEST_F(TestManager, test_keep_roots) { | ||||
| @@ -26,15 +26,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| void CheckNoFreeVariables(FuncGraphPtr root) { | void CheckNoFreeVariables(FuncGraphPtr root) { | ||||
| auto mng = Manage(root); | auto mng = Manage(root); | ||||
| for (auto &iter : mng->nodes()) { | |||||
| auto g = iter.first; | |||||
| auto nodes = iter.second; | |||||
| for (auto &iter : mng->func_graphs()) { | |||||
| auto g = iter; | |||||
| if (g == nullptr) { | if (g == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| ASSERT_TRUE(g->parent() == nullptr); | ASSERT_TRUE(g->parent() == nullptr); | ||||
| auto nodes = g->nodes(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| ASSERT_EQ(node->func_graph(), g); | ASSERT_EQ(node->func_graph(), g); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||