|
|
|
@@ -40,14 +40,14 @@ BasePtr AbsOf(const AnfNodePtr &node) { |
|
|
|
return node_abs; |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t> *const order_group, |
|
|
|
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) { |
|
|
|
MS_EXCEPTION_IF_NULL(order_group); |
|
|
|
|
|
|
|
std::unordered_map<AnfNodePtr, std::size_t> hashes; |
|
|
|
bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { |
|
|
|
bool changed = false; |
|
|
|
for (FuncGraphPtr fg : manager->func_graphs()) { |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
std::vector<std::size_t> order_group; |
|
|
|
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups; |
|
|
|
std::unordered_map<AnfNodePtr, std::size_t> hashes; |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); |
|
|
|
for (auto node : toposet) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -75,17 +75,20 @@ void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t> |
|
|
|
} |
|
|
|
|
|
|
|
hashes[node] = h; |
|
|
|
if (groups->find(h) == groups->end()) { |
|
|
|
if (groups.find(h) == groups.end()) { |
|
|
|
std::vector<AnfNodePtr> innervec({node}); |
|
|
|
(*groups)[h] = innervec; |
|
|
|
order_group->emplace_back(h); |
|
|
|
groups[h] = innervec; |
|
|
|
order_group.emplace_back(h); |
|
|
|
} else { |
|
|
|
(*groups)[h].push_back(node); |
|
|
|
groups[h].push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
changed = DoReplace(manager, order_group, &groups) || changed; |
|
|
|
} |
|
|
|
|
|
|
|
return changed; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
MS_EXCEPTION_IF_NULL(main); |
|
|
|
@@ -177,10 +180,7 @@ bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
manager->AddFuncGraph(root); |
|
|
|
|
|
|
|
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups; |
|
|
|
std::vector<std::size_t> order_group; |
|
|
|
BuildOrderGroup(manager, &order_group, &groups); |
|
|
|
return DoReplace(manager, order_group, &groups); |
|
|
|
return BuildOrderGroupAndDoReplace(manager); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |