| @@ -279,10 +279,8 @@ void FuncGraphSpecializer::FirstPass() { | |||
| // Specialize CNode in func graphs | |||
| void FuncGraphSpecializer::SecondPass() { | |||
| for (auto &node : BroadFirstSearchGraphCNodes({specialized_func_graph_->get_return()})) { | |||
| if (node->isa<CNode>()) { | |||
| ProcessCNode(node->cast<CNodePtr>()); | |||
| } | |||
| for (auto &cnode : BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node())) { | |||
| ProcessCNode(cnode); | |||
| } | |||
| } | |||
| @@ -639,11 +639,11 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | |||
| auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); | |||
| std::list<CNodePtr> cnodes; | |||
| auto nodes = mindspore::TopoSort(get_return(), SuccDepends, BelongSameGraph); | |||
| auto nodes = mindspore::TopoSort(return_node(), SuccDepends, BelongSameGraph); | |||
| for (const auto &node : nodes) { | |||
| auto cnode = dyn_cast<CNode>(node); | |||
| if (cnode) { | |||
| cnodes.push_back(cnode); | |||
| if (cnode != nullptr) { | |||
| cnodes.emplace_back(std::move(cnode)); | |||
| } | |||
| } | |||
| return cnodes; | |||
| @@ -167,7 +167,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo | |||
| // get function graph inputs, but parameters | |||
| const std::vector<AnfNodePtr> get_inputs() const final; | |||
| // Return the graph's output, or nullptr if not yet deduced. | |||
| AnfNodePtr output() const; | |||
| AnfNodePtr output() const final; | |||
| void set_output(const AnfNodePtr &value, bool force_new_ret = false); | |||
| const std::vector<AnfNodePtr> ¶meters() const final { return parameters_; } | |||
| @@ -252,6 +252,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo | |||
| CNodePtr get_return() const final { return return_; } | |||
| void set_return(const CNodePtr &cnode) final { return_ = cnode; } | |||
| const CNodePtr &return_node() const { return return_; } | |||
| FuncGraphManagerPtr manager() const { return manager_.lock(); } | |||
| void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); } | |||
| @@ -50,47 +50,49 @@ static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const An | |||
| } | |||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { | |||
| constexpr size_t kVecReserve = 64; | |||
| std::vector<AnfNodePtr> res; | |||
| if (root == nullptr) { | |||
| return res; | |||
| } | |||
| res.reserve(kVecReserve); | |||
| size_t seen = NewSeenGeneration(); | |||
| std::deque<AnfNodePtr> todo; | |||
| todo.push_back(root); | |||
| todo.emplace_back(root); | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.back(); | |||
| AnfNodePtr &node = todo.back(); | |||
| if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag | |||
| todo.pop_back(); | |||
| continue; | |||
| } | |||
| auto incl = include(node); | |||
| if (node->seen_ == seen) { // We use seen_ as checking flag | |||
| todo.pop_back(); | |||
| node->extra_seen_ = seen; | |||
| if (incl != EXCLUDE) { | |||
| res.push_back(node); | |||
| res.emplace_back(std::move(node)); | |||
| } | |||
| node->extra_seen_ = seen; | |||
| todo.pop_back(); | |||
| continue; | |||
| } | |||
| node->seen_ = seen; | |||
| if (incl == FOLLOW) { | |||
| auto succs = succ(node); | |||
| (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, &todo](const AnfNodePtr &next) { | |||
| for (auto &next : succ(node)) { | |||
| if (next == nullptr || next->extra_seen_ == seen) { | |||
| return false; | |||
| continue; | |||
| } | |||
| if (next->seen_ != seen) { | |||
| return true; | |||
| todo.emplace_back(std::move(next)); | |||
| continue; | |||
| } | |||
| if (next->func_graph() != nullptr && next->func_graph()->get_return() == next) { | |||
| return false; | |||
| auto fg = next->func_graph(); | |||
| if (fg != nullptr && fg->return_node() == next) { | |||
| continue; | |||
| } | |||
| // To dump all nodes in a circle. | |||
| MS_LOG(ERROR) << "Graph cycle exists. Circle is: "; | |||
| auto circle_len = DumpSortingCircleList(todo, next, seen); | |||
| MS_LOG(EXCEPTION) << "Graph cycle exists, size: " << circle_len << ", strike node: " << next->DebugString(2); | |||
| }); | |||
| } | |||
| } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE | |||
| MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; | |||
| } | |||
| @@ -98,28 +100,25 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||
| return res; | |||
| } | |||
| // search the cnodes inside this graph only | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) { | |||
| std::vector<CNodePtr> todo; | |||
| todo.insert(todo.end(), starts.begin(), starts.end()); | |||
| // Search the cnodes inside this graph only. | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &start) { | |||
| constexpr size_t kVecReserve = 64; | |||
| std::vector<CNodePtr> vec; | |||
| vec.reserve(kVecReserve); | |||
| vec.emplace_back(start); | |||
| auto seen = NewSeenGeneration(); | |||
| size_t top_idx = 0; | |||
| while (top_idx < todo.size()) { | |||
| CNodePtr top = todo[top_idx]; | |||
| top_idx++; | |||
| auto inputs = top->inputs(); | |||
| for (auto &item : inputs) { | |||
| if (item->seen_ == seen) { | |||
| continue; | |||
| } | |||
| if (item->isa<CNode>()) { | |||
| todo.push_back(item->cast<CNodePtr>()); | |||
| for (size_t i = 0; i < vec.size(); ++i) { | |||
| CNodePtr &node = vec[i]; | |||
| node->seen_ = seen; | |||
| auto &inputs = node->inputs(); | |||
| for (auto &input : inputs) { | |||
| auto input_cnode = input->cast<CNodePtr>(); | |||
| if (input_cnode != nullptr && input_cnode->seen_ != seen) { | |||
| vec.emplace_back(std::move(input_cnode)); | |||
| } | |||
| item->seen_ = seen; | |||
| } | |||
| } | |||
| return todo; | |||
| return vec; | |||
| } | |||
| // search the cnode match the predicate inside this graph only | |||
| @@ -192,7 +191,7 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| auto graph = GetValueNode<FuncGraphPtr>(node); | |||
| auto ret = graph->get_return(); | |||
| auto &ret = graph->return_node(); | |||
| if (ret != nullptr) { | |||
| vecs.push_back(ret); | |||
| } | |||
| @@ -215,7 +214,7 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| auto graph = GetValueNode<FuncGraphPtr>(node); | |||
| auto ret = graph->get_return(); | |||
| auto &ret = graph->return_node(); | |||
| if (ret != nullptr) { | |||
| vecs.push_back(ret); | |||
| } | |||
| @@ -270,8 +269,6 @@ const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) { | |||
| return empty_inputs; | |||
| } | |||
| IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } | |||
| IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { | |||
| if (node->func_graph() == fg) { | |||
| return FOLLOW; | |||
| @@ -279,66 +276,4 @@ IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { | |||
| return EXCLUDE; | |||
| } | |||
| } | |||
| FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| Acquire(fg); | |||
| auto vec = search(fg->get_return(), include); | |||
| for (auto &node : vec) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| Acquire(node); | |||
| if (node->func_graph() != nullptr) { | |||
| Acquire(node->func_graph()); | |||
| } | |||
| } | |||
| } | |||
| std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) { | |||
| std::set<FuncGraphPtr> func_graphs; | |||
| if (index_func_graph_.find(key) != index_func_graph_.end()) { | |||
| func_graphs = index_func_graph_[key]; | |||
| } | |||
| return func_graphs; | |||
| } | |||
| std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) { | |||
| if (index_node_.find(key) != index_node_.end()) { | |||
| return index_node_[key]; | |||
| } | |||
| return std::set<AnfNodePtr>(); | |||
| } | |||
| FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { | |||
| if (GetFuncGraphs(key).empty()) { | |||
| return nullptr; | |||
| } | |||
| auto fg = *GetFuncGraphs(key).begin(); | |||
| return fg; | |||
| } | |||
| AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { | |||
| if (GetNodes(key).empty()) { | |||
| return nullptr; | |||
| } | |||
| auto node = *GetNodes(key).begin(); | |||
| return node; | |||
| } | |||
| void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { | |||
| std::string name = label_manage::Label(key->debug_info()); | |||
| if (!name.empty()) { | |||
| (void)index_func_graph_[name].insert(key); | |||
| } | |||
| } | |||
| void FuncGraphIndex::Acquire(const AnfNodePtr &key) { | |||
| std::string name = label_manage::Label(key->debug_info()); | |||
| if (!name.empty()) { | |||
| (void)index_node_[name].insert(key); | |||
| } | |||
| } | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,7 @@ | |||
| #include <map> | |||
| #include <set> | |||
| #include <string> | |||
| #include <functional> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| @@ -42,10 +43,7 @@ using FilterFunc = std::function<bool(const AnfNodePtr &)>; | |||
| using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>; | |||
| using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>; | |||
| using MatchFunc = std::function<bool(const CNodePtr &)>; | |||
| std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); | |||
| std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); | |||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); | |||
| using NodeVisitFunc = std::function<void(const AnfNodePtr &)>; | |||
| std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node); | |||
| std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node); | |||
| @@ -54,49 +52,24 @@ std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr & | |||
| const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node); | |||
| IncludeType AlwaysInclude(const AnfNodePtr &node); | |||
| inline IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } | |||
| IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); | |||
| std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, | |||
| const FilterFunc &filter); | |||
| class FuncGraphManager; | |||
| using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | |||
| std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, | |||
| const FuncGraphManagerPtr &mng); | |||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, | |||
| const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts); | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &start); | |||
| std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root); | |||
| CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate); | |||
| class FuncGraphIndex { | |||
| public: | |||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | |||
| const IncludeFunc &include = AlwaysInclude); | |||
| FuncGraphIndex(const FuncGraphIndex &) = delete; | |||
| FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; | |||
| virtual ~FuncGraphIndex() {} | |||
| std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key); | |||
| std::set<AnfNodePtr> GetNodes(const std::string &key); | |||
| FuncGraphPtr GetFirstFuncGraph(const std::string &key); | |||
| AnfNodePtr GetFirstNode(const std::string &key); | |||
| private: | |||
| void Acquire(const FuncGraphPtr &key); | |||
| void Acquire(const AnfNodePtr &key); | |||
| std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_; | |||
| std::map<std::string, std::set<AnfNodePtr>> index_node_; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_GRAPH_UTILS_H_ | |||
| @@ -37,7 +37,10 @@ namespace { | |||
| class DeepFirstSearcher : public AnfIrVisitor { | |||
| public: | |||
| explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr) | |||
| : include_(include), filter_(filter) {} | |||
| : include_(include), filter_(filter) { | |||
| constexpr size_t kVecReserve = 64; | |||
| res_.reserve(kVecReserve); | |||
| } | |||
| ~DeepFirstSearcher() override = default; | |||
| std::vector<AnfNodePtr> Search(const AnfNodePtr &root) { | |||
| @@ -50,13 +53,10 @@ class DeepFirstSearcher : public AnfIrVisitor { | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->seen_ == seen_) { | |||
| if (node == nullptr || node->seen_ == seen_) { | |||
| return; | |||
| } | |||
| node->seen_ = seen_; | |||
| auto incl = include_(node); | |||
| if (incl == EXCLUDE) { | |||
| return; | |||
| @@ -82,14 +82,12 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { | |||
| ~DeepScopedGraphSearcher() override = default; | |||
| void Visit(const CNodePtr &cnode) override { | |||
| if (cnode->func_graph() == nullptr) { | |||
| auto fg = cnode->func_graph(); | |||
| if (fg == nullptr) { | |||
| return; | |||
| } | |||
| AnfNodePtr ret = cnode->func_graph()->get_return(); | |||
| if (ret != nullptr) { | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| AnfNodePtr ret = fg->return_node(); | |||
| DeepFirstSearcher::Visit(ret); | |||
| auto &inputs = cnode->inputs(); | |||
| for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { | |||
| @@ -101,48 +99,18 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { | |||
| if (!IsValueNode<FuncGraph>(vnode)) { | |||
| return; | |||
| } | |||
| auto graph = GetValueNode<FuncGraphPtr>(vnode); | |||
| AnfNodePtr ret = graph->get_return(); | |||
| if (ret != nullptr) { | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| auto fg = GetValueNode<FuncGraphPtr>(vnode); | |||
| AnfNodePtr ret = fg->return_node(); | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| void Visit(const ParameterPtr ¶m) override { | |||
| if (param->func_graph() == nullptr) { | |||
| auto fg = param->func_graph(); | |||
| if (fg == nullptr) { | |||
| return; | |||
| } | |||
| AnfNodePtr ret = param->func_graph()->get_return(); | |||
| if (ret != nullptr) { | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| } | |||
| }; | |||
| class DeepUsedGraphSearcher : public DeepFirstSearcher { | |||
| public: | |||
| explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} | |||
| ~DeepUsedGraphSearcher() override = default; | |||
| void Visit(const CNodePtr &cnode) override { | |||
| auto &inputs = cnode->inputs(); | |||
| for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { | |||
| DeepFirstSearcher::Visit(*iter); | |||
| } | |||
| } | |||
| void Visit(const ValueNodePtr &vnode) override { | |||
| if (!IsValueNode<FuncGraph>(vnode)) { | |||
| return; | |||
| } | |||
| auto graph = GetValueNode<FuncGraphPtr>(vnode); | |||
| AnfNodePtr ret = graph->get_return(); | |||
| if (ret != nullptr) { | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| AnfNodePtr ret = fg->return_node(); | |||
| DeepFirstSearcher::Visit(ret); | |||
| } | |||
| }; | |||
| @@ -160,24 +128,6 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher { | |||
| void Visit(const ValueNodePtr &) override {} | |||
| }; | |||
| class DeepUsersSearcher : public DeepFirstSearcher { | |||
| public: | |||
| explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng) | |||
| : DeepFirstSearcher(include), mng_(mng) {} | |||
| ~DeepUsersSearcher() override = default; | |||
| void Visit(const CNodePtr &cnode) override { | |||
| auto &users = mng_->node_users()[cnode]; | |||
| for (auto iter = users.begin(); iter != users.end(); ++iter) { | |||
| DeepFirstSearcher::Visit(iter->first); | |||
| } | |||
| } | |||
| void Visit(const ValueNodePtr &) override {} | |||
| private: | |||
| FuncGraphManagerPtr mng_; | |||
| }; | |||
| } // namespace | |||
| // include for if expand the node the search, filter for if put the node to results. | |||
| @@ -190,16 +140,7 @@ std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, | |||
| return DeepFirstSearcher(include, filter).Search(root); | |||
| } | |||
| std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { | |||
| return DeepUsedGraphSearcher(include).Search(root); | |||
| } | |||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { | |||
| return DeepLinkedGraphSearcher(include).Search(root); | |||
| } | |||
| std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, | |||
| const FuncGraphManagerPtr &mng) { | |||
| return DeepUsersSearcher(include, mng).Search(root); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,86 @@ | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| class FuncGraphIndex { | |||
| public: | |||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | |||
| const IncludeFunc &include = AlwaysInclude); | |||
| FuncGraphIndex(const FuncGraphIndex &) = delete; | |||
| FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; | |||
| virtual ~FuncGraphIndex() {} | |||
| std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key); | |||
| std::set<AnfNodePtr> GetNodes(const std::string &key); | |||
| FuncGraphPtr GetFirstFuncGraph(const std::string &key); | |||
| AnfNodePtr GetFirstNode(const std::string &key); | |||
| private: | |||
| void Acquire(const FuncGraphPtr &key); | |||
| void Acquire(const AnfNodePtr &key); | |||
| std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_; | |||
| std::map<std::string, std::set<AnfNodePtr>> index_node_; | |||
| }; | |||
| FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| Acquire(fg); | |||
| auto vec = search(fg->get_return(), include); | |||
| for (auto &node : vec) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| Acquire(node); | |||
| if (node->func_graph() != nullptr) { | |||
| Acquire(node->func_graph()); | |||
| } | |||
| } | |||
| } | |||
| std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) { | |||
| std::set<FuncGraphPtr> func_graphs; | |||
| if (index_func_graph_.find(key) != index_func_graph_.end()) { | |||
| func_graphs = index_func_graph_[key]; | |||
| } | |||
| return func_graphs; | |||
| } | |||
| std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) { | |||
| if (index_node_.find(key) != index_node_.end()) { | |||
| return index_node_[key]; | |||
| } | |||
| return std::set<AnfNodePtr>(); | |||
| } | |||
| FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { | |||
| if (GetFuncGraphs(key).empty()) { | |||
| return nullptr; | |||
| } | |||
| auto fg = *GetFuncGraphs(key).begin(); | |||
| return fg; | |||
| } | |||
| AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { | |||
| if (GetNodes(key).empty()) { | |||
| return nullptr; | |||
| } | |||
| auto node = *GetNodes(key).begin(); | |||
| return node; | |||
| } | |||
| void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { | |||
| std::string name = label_manage::Label(key->debug_info()); | |||
| if (!name.empty()) { | |||
| (void)index_func_graph_[name].insert(key); | |||
| } | |||
| } | |||
| void FuncGraphIndex::Acquire(const AnfNodePtr &key) { | |||
| std::string name = label_manage::Label(key->debug_info()); | |||
| if (!name.empty()) { | |||
| (void)index_node_[name].insert(key); | |||
| } | |||
| } | |||
| class TestCloner : public UT::Common { | |||
| public: | |||
| TestCloner() : getPyFun("gtest_input.ir.clone_test", true) { | |||
| @@ -36,7 +116,7 @@ class TestCloner : public UT::Common { | |||
| } | |||
| FuncGraphPtr GraphForInline() { return nullptr; } | |||
| void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params, | |||
| void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr> ¶ms, | |||
| FuncGraphPtr target); | |||
| public: | |||
| @@ -48,7 +128,7 @@ class TestCloner : public UT::Common { | |||
| }; | |||
| void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, | |||
| const std::vector<AnfNodePtr>& params, FuncGraphPtr target) { | |||
| const std::vector<AnfNodePtr> ¶ms, FuncGraphPtr target) { | |||
| auto g = (*cl)[orig]; | |||
| ASSERT_TRUE(g != target); | |||
| ASSERT_TRUE(g == orig); | |||
| @@ -59,11 +139,11 @@ void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphP | |||
| AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output())); | |||
| AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root)); | |||
| for (auto& p : params) { | |||
| for (auto &p : params) { | |||
| ASSERT_TRUE(new_nodes.contains(p)); | |||
| } | |||
| for (auto& node : orig_nodes) { | |||
| for (auto &node : orig_nodes) { | |||
| if (node->func_graph() == orig) { | |||
| ASSERT_TRUE((*cl)[node]); | |||
| } | |||
| @@ -93,7 +173,7 @@ TEST_F(TestCloner, test_clone_simple) { | |||
| std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")}; | |||
| AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return())); | |||
| common = d1 & d3; | |||
| for (auto& x : common) { | |||
| for (auto &x : common) { | |||
| ASSERT_TRUE(x->isa<ValueNode>()); | |||
| ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) != | |||
| results.end()); | |||