Browse Source

Optimize graph_utils

tags/v1.6.0
He Wei 4 years ago
parent
commit
205b6357ed
7 changed files with 145 additions and 217 deletions
  1. +2
    -4
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  2. +3
    -3
      mindspore/core/ir/func_graph.cc
  3. +2
    -1
      mindspore/core/ir/func_graph.h
  4. +32
    -97
      mindspore/core/ir/graph_utils.cc
  5. +5
    -32
      mindspore/core/ir/graph_utils.h
  6. +16
    -75
      mindspore/core/ir/graph_utils_extends.cc
  7. +85
    -5
      tests/ut/cpp/ir/clone_test.cc

+ 2
- 4
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -279,10 +279,8 @@ void FuncGraphSpecializer::FirstPass() {

// Specialize CNode in func graphs
void FuncGraphSpecializer::SecondPass() {
for (auto &node : BroadFirstSearchGraphCNodes({specialized_func_graph_->get_return()})) {
if (node->isa<CNode>()) {
ProcessCNode(node->cast<CNodePtr>());
}
for (auto &cnode : BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node())) {
ProcessCNode(cnode);
}
}



+ 3
- 3
mindspore/core/ir/func_graph.cc View File

@@ -639,11 +639,11 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);

std::list<CNodePtr> cnodes;
auto nodes = mindspore::TopoSort(get_return(), SuccDepends, BelongSameGraph);
auto nodes = mindspore::TopoSort(return_node(), SuccDepends, BelongSameGraph);
for (const auto &node : nodes) {
auto cnode = dyn_cast<CNode>(node);
if (cnode) {
cnodes.push_back(cnode);
if (cnode != nullptr) {
cnodes.emplace_back(std::move(cnode));
}
}
return cnodes;


+ 2
- 1
mindspore/core/ir/func_graph.h View File

@@ -167,7 +167,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
// get function graph inputs, but parameters
const std::vector<AnfNodePtr> get_inputs() const final;
// Return the graph's output, or nullptr if not yet deduced.
AnfNodePtr output() const;
AnfNodePtr output() const final;
void set_output(const AnfNodePtr &value, bool force_new_ret = false);

const std::vector<AnfNodePtr> &parameters() const final { return parameters_; }
@@ -252,6 +252,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo

CNodePtr get_return() const final { return return_; }
void set_return(const CNodePtr &cnode) final { return_ = cnode; }
const CNodePtr &return_node() const { return return_; }

FuncGraphManagerPtr manager() const { return manager_.lock(); }
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }


+ 32
- 97
mindspore/core/ir/graph_utils.cc View File

@@ -50,47 +50,49 @@ static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const An
}

std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
constexpr size_t kVecReserve = 64;
std::vector<AnfNodePtr> res;
if (root == nullptr) {
return res;
}
res.reserve(kVecReserve);
size_t seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo;
todo.push_back(root);

todo.emplace_back(root);
while (!todo.empty()) {
AnfNodePtr node = todo.back();
AnfNodePtr &node = todo.back();
if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag
todo.pop_back();
continue;
}
auto incl = include(node);
if (node->seen_ == seen) { // We use seen_ as checking flag
todo.pop_back();
node->extra_seen_ = seen;
if (incl != EXCLUDE) {
res.push_back(node);
res.emplace_back(std::move(node));
}
node->extra_seen_ = seen;
todo.pop_back();
continue;
}
node->seen_ = seen;
if (incl == FOLLOW) {
auto succs = succ(node);
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, &todo](const AnfNodePtr &next) {
for (auto &next : succ(node)) {
if (next == nullptr || next->extra_seen_ == seen) {
return false;
continue;
}
if (next->seen_ != seen) {
return true;
todo.emplace_back(std::move(next));
continue;
}
if (next->func_graph() != nullptr && next->func_graph()->get_return() == next) {
return false;
auto fg = next->func_graph();
if (fg != nullptr && fg->return_node() == next) {
continue;
}
// To dump all nodes in a circle.
MS_LOG(ERROR) << "Graph cycle exists. Circle is: ";
auto circle_len = DumpSortingCircleList(todo, next, seen);
MS_LOG(EXCEPTION) << "Graph cycle exists, size: " << circle_len << ", strike node: " << next->DebugString(2);
});
}
} else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
}
@@ -98,28 +100,25 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
return res;
}

// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) {
std::vector<CNodePtr> todo;
todo.insert(todo.end(), starts.begin(), starts.end());
// Search the cnodes inside this graph only.
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &start) {
constexpr size_t kVecReserve = 64;
std::vector<CNodePtr> vec;
vec.reserve(kVecReserve);
vec.emplace_back(start);
auto seen = NewSeenGeneration();
size_t top_idx = 0;
while (top_idx < todo.size()) {
CNodePtr top = todo[top_idx];
top_idx++;
auto inputs = top->inputs();
for (auto &item : inputs) {
if (item->seen_ == seen) {
continue;
}

if (item->isa<CNode>()) {
todo.push_back(item->cast<CNodePtr>());
for (size_t i = 0; i < vec.size(); ++i) {
CNodePtr &node = vec[i];
node->seen_ = seen;
auto &inputs = node->inputs();
for (auto &input : inputs) {
auto input_cnode = input->cast<CNodePtr>();
if (input_cnode != nullptr && input_cnode->seen_ != seen) {
vec.emplace_back(std::move(input_cnode));
}
item->seen_ = seen;
}
}
return todo;
return vec;
}

// search the cnode match the predicate inside this graph only
@@ -192,7 +191,7 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {

if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto ret = graph->get_return();
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
@@ -215,7 +214,7 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {

if (IsValueNode<FuncGraph>(node)) {
auto graph = GetValueNode<FuncGraphPtr>(node);
auto ret = graph->get_return();
auto &ret = graph->return_node();
if (ret != nullptr) {
vecs.push_back(ret);
}
@@ -270,8 +269,6 @@ const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
return empty_inputs;
}

IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; }

IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
if (node->func_graph() == fg) {
return FOLLOW;
@@ -279,66 +276,4 @@ IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
return EXCLUDE;
}
}

FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) {
MS_EXCEPTION_IF_NULL(fg);
Acquire(fg);

auto vec = search(fg->get_return(), include);
for (auto &node : vec) {
MS_EXCEPTION_IF_NULL(node);
Acquire(node);
if (node->func_graph() != nullptr) {
Acquire(node->func_graph());
}
}
}

std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) {
std::set<FuncGraphPtr> func_graphs;
if (index_func_graph_.find(key) != index_func_graph_.end()) {
func_graphs = index_func_graph_[key];
}
return func_graphs;
}

std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) {
if (index_node_.find(key) != index_node_.end()) {
return index_node_[key];
}

return std::set<AnfNodePtr>();
}

FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) {
if (GetFuncGraphs(key).empty()) {
return nullptr;
}

auto fg = *GetFuncGraphs(key).begin();
return fg;
}

AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) {
if (GetNodes(key).empty()) {
return nullptr;
}

auto node = *GetNodes(key).begin();
return node;
}

void FuncGraphIndex::Acquire(const FuncGraphPtr &key) {
std::string name = label_manage::Label(key->debug_info());
if (!name.empty()) {
(void)index_func_graph_[name].insert(key);
}
}

void FuncGraphIndex::Acquire(const AnfNodePtr &key) {
std::string name = label_manage::Label(key->debug_info());
if (!name.empty()) {
(void)index_node_[name].insert(key);
}
}
} // namespace mindspore

+ 5
- 32
mindspore/core/ir/graph_utils.h View File

@@ -27,6 +27,7 @@
#include <map>
#include <set>
#include <string>
#include <functional>

#include "ir/anf.h"
#include "ir/primitive.h"
@@ -42,10 +43,7 @@ using FilterFunc = std::function<bool(const AnfNodePtr &)>;
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
using MatchFunc = std::function<bool(const CNodePtr &)>;

std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
using NodeVisitFunc = std::function<void(const AnfNodePtr &)>;

std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node);
std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node);
@@ -54,49 +52,24 @@ std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &

const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node);

IncludeType AlwaysInclude(const AnfNodePtr &node);
inline IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; }
IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node);

std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);

std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
const FilterFunc &filter);

class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include,
const FuncGraphManagerPtr &mng);
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
const IncludeFunc &include = AlwaysInclude);

std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts);
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &start);
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root);

CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate);

class FuncGraphIndex {
public:
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
const IncludeFunc &include = AlwaysInclude);
FuncGraphIndex(const FuncGraphIndex &) = delete;
FuncGraphIndex &operator=(const FuncGraphIndex &) = delete;

virtual ~FuncGraphIndex() {}

std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key);
std::set<AnfNodePtr> GetNodes(const std::string &key);
FuncGraphPtr GetFirstFuncGraph(const std::string &key);
AnfNodePtr GetFirstNode(const std::string &key);

private:
void Acquire(const FuncGraphPtr &key);
void Acquire(const AnfNodePtr &key);

std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_;
std::map<std::string, std::set<AnfNodePtr>> index_node_;
};
} // namespace mindspore

#endif // MINDSPORE_CORE_IR_GRAPH_UTILS_H_

+ 16
- 75
mindspore/core/ir/graph_utils_extends.cc View File

@@ -37,7 +37,10 @@ namespace {
class DeepFirstSearcher : public AnfIrVisitor {
public:
explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
: include_(include), filter_(filter) {}
: include_(include), filter_(filter) {
constexpr size_t kVecReserve = 64;
res_.reserve(kVecReserve);
}
~DeepFirstSearcher() override = default;

std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
@@ -50,13 +53,10 @@ class DeepFirstSearcher : public AnfIrVisitor {
}

void Visit(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
if (node->seen_ == seen_) {
if (node == nullptr || node->seen_ == seen_) {
return;
}

node->seen_ = seen_;

auto incl = include_(node);
if (incl == EXCLUDE) {
return;
@@ -82,14 +82,12 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher {
~DeepScopedGraphSearcher() override = default;

void Visit(const CNodePtr &cnode) override {
if (cnode->func_graph() == nullptr) {
auto fg = cnode->func_graph();
if (fg == nullptr) {
return;
}

AnfNodePtr ret = cnode->func_graph()->get_return();
if (ret != nullptr) {
DeepFirstSearcher::Visit(ret);
}
AnfNodePtr ret = fg->return_node();
DeepFirstSearcher::Visit(ret);

auto &inputs = cnode->inputs();
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
@@ -101,48 +99,18 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher {
if (!IsValueNode<FuncGraph>(vnode)) {
return;
}

auto graph = GetValueNode<FuncGraphPtr>(vnode);
AnfNodePtr ret = graph->get_return();
if (ret != nullptr) {
DeepFirstSearcher::Visit(ret);
}
auto fg = GetValueNode<FuncGraphPtr>(vnode);
AnfNodePtr ret = fg->return_node();
DeepFirstSearcher::Visit(ret);
}

void Visit(const ParameterPtr &param) override {
if (param->func_graph() == nullptr) {
auto fg = param->func_graph();
if (fg == nullptr) {
return;
}

AnfNodePtr ret = param->func_graph()->get_return();
if (ret != nullptr) {
DeepFirstSearcher::Visit(ret);
}
}
};

class DeepUsedGraphSearcher : public DeepFirstSearcher {
public:
explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
~DeepUsedGraphSearcher() override = default;

void Visit(const CNodePtr &cnode) override {
auto &inputs = cnode->inputs();
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
DeepFirstSearcher::Visit(*iter);
}
}

void Visit(const ValueNodePtr &vnode) override {
if (!IsValueNode<FuncGraph>(vnode)) {
return;
}

auto graph = GetValueNode<FuncGraphPtr>(vnode);
AnfNodePtr ret = graph->get_return();
if (ret != nullptr) {
DeepFirstSearcher::Visit(ret);
}
AnfNodePtr ret = fg->return_node();
DeepFirstSearcher::Visit(ret);
}
};

@@ -160,24 +128,6 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {

void Visit(const ValueNodePtr &) override {}
};

class DeepUsersSearcher : public DeepFirstSearcher {
public:
explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng)
: DeepFirstSearcher(include), mng_(mng) {}
~DeepUsersSearcher() override = default;

void Visit(const CNodePtr &cnode) override {
auto &users = mng_->node_users()[cnode];
for (auto iter = users.begin(); iter != users.end(); ++iter) {
DeepFirstSearcher::Visit(iter->first);
}
}
void Visit(const ValueNodePtr &) override {}

private:
FuncGraphManagerPtr mng_;
};
} // namespace

// include for if expand the node the search, filter for if put the node to results.
@@ -190,16 +140,7 @@ std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root,
return DeepFirstSearcher(include, filter).Search(root);
}

std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
return DeepUsedGraphSearcher(include).Search(root);
}

std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
return DeepLinkedGraphSearcher(include).Search(root);
}

std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include,
const FuncGraphManagerPtr &mng) {
return DeepUsersSearcher(include, mng).Search(root);
}
} // namespace mindspore

+ 85
- 5
tests/ut/cpp/ir/clone_test.cc View File

@@ -27,6 +27,86 @@
#include "base/core_ops.h"

namespace mindspore {
class FuncGraphIndex {
public:
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
const IncludeFunc &include = AlwaysInclude);
FuncGraphIndex(const FuncGraphIndex &) = delete;
FuncGraphIndex &operator=(const FuncGraphIndex &) = delete;

virtual ~FuncGraphIndex() {}

std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key);
std::set<AnfNodePtr> GetNodes(const std::string &key);
FuncGraphPtr GetFirstFuncGraph(const std::string &key);
AnfNodePtr GetFirstNode(const std::string &key);

private:
void Acquire(const FuncGraphPtr &key);
void Acquire(const AnfNodePtr &key);

std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_;
std::map<std::string, std::set<AnfNodePtr>> index_node_;
};

FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) {
MS_EXCEPTION_IF_NULL(fg);
Acquire(fg);
auto vec = search(fg->get_return(), include);
for (auto &node : vec) {
MS_EXCEPTION_IF_NULL(node);
Acquire(node);
if (node->func_graph() != nullptr) {
Acquire(node->func_graph());
}
}
}

std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) {
std::set<FuncGraphPtr> func_graphs;
if (index_func_graph_.find(key) != index_func_graph_.end()) {
func_graphs = index_func_graph_[key];
}
return func_graphs;
}

std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) {
if (index_node_.find(key) != index_node_.end()) {
return index_node_[key];
}
return std::set<AnfNodePtr>();
}

FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) {
if (GetFuncGraphs(key).empty()) {
return nullptr;
}
auto fg = *GetFuncGraphs(key).begin();
return fg;
}

AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) {
if (GetNodes(key).empty()) {
return nullptr;
}
auto node = *GetNodes(key).begin();
return node;
}

void FuncGraphIndex::Acquire(const FuncGraphPtr &key) {
std::string name = label_manage::Label(key->debug_info());
if (!name.empty()) {
(void)index_func_graph_[name].insert(key);
}
}

void FuncGraphIndex::Acquire(const AnfNodePtr &key) {
std::string name = label_manage::Label(key->debug_info());
if (!name.empty()) {
(void)index_node_[name].insert(key);
}
}

class TestCloner : public UT::Common {
public:
TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
@@ -36,7 +116,7 @@ class TestCloner : public UT::Common {
}

FuncGraphPtr GraphForInline() { return nullptr; }
void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params,
void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr> &params,
FuncGraphPtr target);

public:
@@ -48,7 +128,7 @@ class TestCloner : public UT::Common {
};

void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
const std::vector<AnfNodePtr>& params, FuncGraphPtr target) {
const std::vector<AnfNodePtr> &params, FuncGraphPtr target) {
auto g = (*cl)[orig];
ASSERT_TRUE(g != target);
ASSERT_TRUE(g == orig);
@@ -59,11 +139,11 @@ void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphP
AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));

for (auto& p : params) {
for (auto &p : params) {
ASSERT_TRUE(new_nodes.contains(p));
}

for (auto& node : orig_nodes) {
for (auto &node : orig_nodes) {
if (node->func_graph() == orig) {
ASSERT_TRUE((*cl)[node]);
}
@@ -93,7 +173,7 @@ TEST_F(TestCloner, test_clone_simple) {
std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")};
AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
common = d1 & d3;
for (auto& x : common) {
for (auto &x : common) {
ASSERT_TRUE(x->isa<ValueNode>());
ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) !=
results.end());


Loading…
Cancel
Save