Browse Source

Do cse graph by graph

tags/v0.2.0-alpha
lyfne 5 years ago
parent
commit
b7076d260e
2 changed files with 16 additions and 15 deletions
  1. +15
    -15
      mindspore/ccsrc/optimizer/cse.cc
  2. +1
    -0
      mindspore/ccsrc/optimizer/cse.h

+ 15
- 15
mindspore/ccsrc/optimizer/cse.cc View File

@@ -40,14 +40,14 @@ BasePtr AbsOf(const AnfNodePtr &node) {
return node_abs;
}

namespace {
void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t> *const order_group,
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) {
MS_EXCEPTION_IF_NULL(order_group);

std::unordered_map<AnfNodePtr, std::size_t> hashes;
bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
bool changed = false;
for (FuncGraphPtr fg : manager->func_graphs()) {
MS_EXCEPTION_IF_NULL(fg);
std::vector<std::size_t> order_group;
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::unordered_map<AnfNodePtr, std::size_t> hashes;

std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
@@ -75,17 +75,20 @@ void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t>
}

hashes[node] = h;
if (groups->find(h) == groups->end()) {
if (groups.find(h) == groups.end()) {
std::vector<AnfNodePtr> innervec({node});
(*groups)[h] = innervec;
order_group->emplace_back(h);
groups[h] = innervec;
order_group.emplace_back(h);
} else {
(*groups)[h].push_back(node);
groups[h].push_back(node);
}
}

changed = DoReplace(manager, order_group, &groups) || changed;
}

return changed;
}
} // namespace

bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
@@ -177,10 +180,7 @@ bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);

std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::vector<std::size_t> order_group;
BuildOrderGroup(manager, &order_group, &groups);
return DoReplace(manager, order_group, &groups);
return BuildOrderGroupAndDoReplace(manager);
}
} // namespace opt
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/optimizer/cse.h View File

@@ -46,6 +46,7 @@ class CSE {
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;

private:
bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
bool report_changes_;


Loading…
Cancel
Save