Merge pull request !926 from biffex/ir-add-seen-generation-to-accelerate-traverse-the-whole-graphtags/v0.3.0-alpha
| @@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { | |||
| } | |||
| return false; | |||
| } | |||
| size_t NewSeenGeneration() { | |||
| static size_t seen_generation = 0; | |||
| return ++seen_generation; | |||
| } | |||
| namespace id_generator { | |||
| static std::unordered_map<std::string, int> node_ids; | |||
| std::string get_id(const AnfNodePtr &node) { | |||
| @@ -155,6 +155,7 @@ class AnfNode : public Base { | |||
| os << node.ToString(); | |||
| return os; | |||
| } | |||
| size_t seen_{0}; | |||
| protected: | |||
| // Hold a weak ref to Graph as Graph also hold ref to AnfNode. | |||
| @@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) { | |||
| auto s = value->cast<S>(); | |||
| return s; | |||
| } | |||
| size_t NewSeenGeneration(); | |||
| namespace id_generator { | |||
| std::string get_id(const AnfNodePtr &node); | |||
| void reset_id(); | |||
| @@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode | |||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||
| const SubstitutionPtr &transform) const { | |||
| #ifdef ENABLE_PROFILE | |||
| double start = GetTime(); | |||
| #endif | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| std::unordered_set<AnfNodePtr> seen_node; | |||
| std::deque<AnfNodePtr> todo{root_node}; | |||
| auto seen = NewSeenGeneration(); | |||
| // 1024 is for the initial capacity of deque | |||
| std::deque<AnfNodePtr> todo(1024); | |||
| todo.push_back(root_node); | |||
| bool changes = false; | |||
| auto &all_nodes = manager->all_nodes(); | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.front(); | |||
| todo.pop_front(); | |||
| // check whether this node has been matched. | |||
| if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) { | |||
| if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) { | |||
| continue; | |||
| } | |||
| (void)seen_node.insert(node); | |||
| node->seen_ = seen; | |||
| // select nodes that this transform can be applied. | |||
| bool is_match = transform->predicate_(node); | |||
| @@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| auto ret = (*transform)(optimizer, node); | |||
| if (ret != nullptr && ret != node) { | |||
| change = true; | |||
| changes = true; | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| @@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| if (change && node_users.find(node) != node_users.end()) { | |||
| for (auto &use : node_users[node]) { | |||
| auto use_node = use.first; | |||
| if (use_node == nullptr) { | |||
| continue; | |||
| } | |||
| todo.push_back(use_node); | |||
| if (seen_node.find(use_node) != seen_node.end()) { | |||
| (void)seen_node.erase(use_node); | |||
| if (use_node->seen_ == seen) { | |||
| use_node->seen_--; | |||
| } | |||
| } | |||
| } | |||
| changes = changes || change; | |||
| } | |||
| #ifdef ENABLE_PROFILE | |||
| MsProfile::StatTime("opt.transform", GetTime() - start); | |||
| #endif | |||
| return changes; | |||
| } | |||
| @@ -48,8 +48,8 @@ class Substitution { | |||
| PredicateFuncType predicate_{nullptr}; | |||
| // an enum to mark this Substitution relation to renormalize pass | |||
| RenormAction renorm_action_; | |||
| explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| const RenormAction &renorm_action) | |||
| Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| const RenormAction &renorm_action) | |||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | |||
| ~Substitution() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | |||
| @@ -46,17 +46,18 @@ class DeepFirstSearcher : public AnfVisitor { | |||
| if (root == nullptr) { | |||
| return res_; | |||
| } | |||
| seen_ = NewSeenGeneration(); | |||
| Visit(root); | |||
| return res_; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (seen_.count(node) != 0) { | |||
| if (node->seen_ == seen_) { | |||
| return; | |||
| } | |||
| (void)seen_.insert(node); | |||
| node->seen_ = seen_; | |||
| auto incl = include_(node); | |||
| if (incl == EXCLUDE) { | |||
| @@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor { | |||
| } | |||
| private: | |||
| size_t seen_{0}; | |||
| IncludeFunc include_; | |||
| std::vector<AnfNodePtr> res_{}; | |||
| std::set<AnfNodePtr> seen_{}; | |||
| }; | |||
| class DeepScopedGraphSearcher : public DeepFirstSearcher { | |||
| @@ -174,14 +175,14 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl | |||
| } | |||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { | |||
| std::unordered_set<AnfNodePtr> done; | |||
| size_t seen = NewSeenGeneration(); | |||
| std::list<AnfNodePtr> todo(1, root); | |||
| std::unordered_map<AnfNodePtr, size_t> rank; | |||
| std::vector<AnfNodePtr> res; | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.back(); | |||
| if (done.find(node) != done.end()) { | |||
| if (node == nullptr || node->seen_ == seen) { | |||
| todo.pop_back(); | |||
| continue; | |||
| } | |||
| @@ -194,7 +195,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||
| if (incl == FOLLOW) { | |||
| auto succs = succ(node); | |||
| for (const auto i : succs) { | |||
| if ((done.find(i) == done.end()) | |||
| if ((i != nullptr && i->seen_ != seen) | |||
| // Handle the case for 2 subgraphs calls each other. | |||
| // If the ValueNodeGraph's return is already in the todo list, do not follow it. | |||
| && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && | |||
| @@ -206,7 +207,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||
| } else if (incl == NOFOLLOW) { | |||
| // do nothing | |||
| } else if (incl == EXCLUDE) { | |||
| (void)done.insert(node); | |||
| node->seen_ = seen; | |||
| todo.pop_back(); | |||
| continue; | |||
| } else { | |||
| @@ -215,7 +216,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||
| if (cont) { | |||
| continue; | |||
| } | |||
| (void)done.insert(node); | |||
| node->seen_ = seen; | |||
| res.push_back(node); | |||
| todo.pop_back(); | |||
| } | |||