Merge pull request !926 from biffex/ir-add-seen-generation-to-accelerate-traverse-the-whole-graph (tags/v0.3.0-alpha)
| @@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
// Returns a fresh "seen generation" number used to mark nodes visited during one traversal.
// The counter is a function-local static that starts at 0 and is pre-incremented, so the
// first call returns 1 and each later call returns a value one larger than the previous one.
// NOTE(review): the counter is a plain size_t with no synchronization visible here;
// confirm that traversals are never started concurrently from multiple threads.
| size_t NewSeenGeneration() { | |||||
| static size_t seen_generation = 0; | |||||
| return ++seen_generation; | |||||
| } | |||||
| namespace id_generator { | namespace id_generator { | ||||
| static std::unordered_map<std::string, int> node_ids; | static std::unordered_map<std::string, int> node_ids; | ||||
| std::string get_id(const AnfNodePtr &node) { | std::string get_id(const AnfNodePtr &node) { | ||||
| @@ -155,6 +155,7 @@ class AnfNode : public Base { | |||||
| os << node.ToString(); | os << node.ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| size_t seen_{0}; | |||||
| protected: | protected: | ||||
| // Hold a weak ref to Graph, as Graph also holds a ref to AnfNode. | // Hold a weak ref to Graph, as Graph also holds a ref to AnfNode. | ||||
| @@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) { | |||||
| auto s = value->cast<S>(); | auto s = value->cast<S>(); | ||||
| return s; | return s; | ||||
| } | } | ||||
| size_t NewSeenGeneration(); | |||||
| namespace id_generator { | namespace id_generator { | ||||
| std::string get_id(const AnfNodePtr &node); | std::string get_id(const AnfNodePtr &node); | ||||
| void reset_id(); | void reset_id(); | ||||
| @@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode | |||||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | ||||
| const SubstitutionPtr &transform) const { | const SubstitutionPtr &transform) const { | ||||
| #ifdef ENABLE_PROFILE | |||||
| double start = GetTime(); | |||||
| #endif | |||||
| FuncGraphManagerPtr manager = optimizer->manager(); | FuncGraphManagerPtr manager = optimizer->manager(); | ||||
| std::unordered_set<AnfNodePtr> seen_node; | |||||
| std::deque<AnfNodePtr> todo{root_node}; | |||||
| auto seen = NewSeenGeneration(); | |||||
| // NOTE: deque(1024) constructs 1024 value-initialized (null) entries, not reserved capacity; the nullptr check in the loop below skips them. | |||||
| std::deque<AnfNodePtr> todo(1024); | |||||
| todo.push_back(root_node); | |||||
| bool changes = false; | bool changes = false; | ||||
| auto &all_nodes = manager->all_nodes(); | |||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| AnfNodePtr node = todo.front(); | AnfNodePtr node = todo.front(); | ||||
| todo.pop_front(); | todo.pop_front(); | ||||
| // check whether this node has been matched. | // check whether this node has been matched. | ||||
| if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) { | |||||
| if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| (void)seen_node.insert(node); | |||||
| node->seen_ = seen; | |||||
| // select nodes to which this transform can be applied. | // select nodes to which this transform can be applied. | ||||
| bool is_match = transform->predicate_(node); | bool is_match = transform->predicate_(node); | ||||
| @@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||||
| auto ret = (*transform)(optimizer, node); | auto ret = (*transform)(optimizer, node); | ||||
| if (ret != nullptr && ret != node) { | if (ret != nullptr && ret != node) { | ||||
| change = true; | change = true; | ||||
| changes = true; | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double t = GetTime(); | double t = GetTime(); | ||||
| #endif | #endif | ||||
| @@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||||
| if (change && node_users.find(node) != node_users.end()) { | if (change && node_users.find(node) != node_users.end()) { | ||||
| for (auto &use : node_users[node]) { | for (auto &use : node_users[node]) { | ||||
| auto use_node = use.first; | auto use_node = use.first; | ||||
| if (use_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| todo.push_back(use_node); | todo.push_back(use_node); | ||||
| if (seen_node.find(use_node) != seen_node.end()) { | |||||
| (void)seen_node.erase(use_node); | |||||
| if (use_node->seen_ == seen) { | |||||
| use_node->seen_--; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| changes = changes || change; | |||||
| } | } | ||||
| #ifdef ENABLE_PROFILE | |||||
| MsProfile::StatTime("opt.transform", GetTime() - start); | |||||
| #endif | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| @@ -48,8 +48,8 @@ class Substitution { | |||||
| PredicateFuncType predicate_{nullptr}; | PredicateFuncType predicate_{nullptr}; | ||||
| // an enum marking how this Substitution relates to the renormalize pass | // an enum marking how this Substitution relates to the renormalize pass | ||||
| RenormAction renorm_action_; | RenormAction renorm_action_; | ||||
| explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||||
| const RenormAction &renorm_action) | |||||
| Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||||
| const RenormAction &renorm_action) | |||||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | ||||
| ~Substitution() = default; | ~Substitution() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | ||||
| @@ -46,17 +46,18 @@ class DeepFirstSearcher : public AnfVisitor { | |||||
| if (root == nullptr) { | if (root == nullptr) { | ||||
| return res_; | return res_; | ||||
| } | } | ||||
| seen_ = NewSeenGeneration(); | |||||
| Visit(root); | Visit(root); | ||||
| return res_; | return res_; | ||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (seen_.count(node) != 0) { | |||||
| if (node->seen_ == seen_) { | |||||
| return; | return; | ||||
| } | } | ||||
| (void)seen_.insert(node); | |||||
| node->seen_ = seen_; | |||||
| auto incl = include_(node); | auto incl = include_(node); | ||||
| if (incl == EXCLUDE) { | if (incl == EXCLUDE) { | ||||
| @@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor { | |||||
| } | } | ||||
| private: | private: | ||||
| size_t seen_{0}; | |||||
| IncludeFunc include_; | IncludeFunc include_; | ||||
| std::vector<AnfNodePtr> res_{}; | std::vector<AnfNodePtr> res_{}; | ||||
| std::set<AnfNodePtr> seen_{}; | |||||
| }; | }; | ||||
| class DeepScopedGraphSearcher : public DeepFirstSearcher { | class DeepScopedGraphSearcher : public DeepFirstSearcher { | ||||
| @@ -174,14 +175,14 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl | |||||
| } | } | ||||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { | std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { | ||||
| std::unordered_set<AnfNodePtr> done; | |||||
| size_t seen = NewSeenGeneration(); | |||||
| std::list<AnfNodePtr> todo(1, root); | std::list<AnfNodePtr> todo(1, root); | ||||
| std::unordered_map<AnfNodePtr, size_t> rank; | std::unordered_map<AnfNodePtr, size_t> rank; | ||||
| std::vector<AnfNodePtr> res; | std::vector<AnfNodePtr> res; | ||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| AnfNodePtr node = todo.back(); | AnfNodePtr node = todo.back(); | ||||
| if (done.find(node) != done.end()) { | |||||
| if (node == nullptr || node->seen_ == seen) { | |||||
| todo.pop_back(); | todo.pop_back(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -194,7 +195,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||||
| if (incl == FOLLOW) { | if (incl == FOLLOW) { | ||||
| auto succs = succ(node); | auto succs = succ(node); | ||||
| for (const auto i : succs) { | for (const auto i : succs) { | ||||
| if ((done.find(i) == done.end()) | |||||
| if ((i != nullptr && i->seen_ != seen) | |||||
| // Handle the case where 2 subgraphs call each other. | // Handle the case where 2 subgraphs call each other. | ||||
| // If the ValueNodeGraph's return is already in the todo list, do not follow it. | // If the ValueNodeGraph's return is already in the todo list, do not follow it. | ||||
| && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && | && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && | ||||
| @@ -206,7 +207,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||||
| } else if (incl == NOFOLLOW) { | } else if (incl == NOFOLLOW) { | ||||
| // do nothing | // do nothing | ||||
| } else if (incl == EXCLUDE) { | } else if (incl == EXCLUDE) { | ||||
| (void)done.insert(node); | |||||
| node->seen_ = seen; | |||||
| todo.pop_back(); | todo.pop_back(); | ||||
| continue; | continue; | ||||
| } else { | } else { | ||||
| @@ -215,7 +216,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||||
| if (cont) { | if (cont) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| (void)done.insert(node); | |||||
| node->seen_ = seen; | |||||
| res.push_back(node); | res.push_back(node); | ||||
| todo.pop_back(); | todo.pop_back(); | ||||
| } | } | ||||