|
|
|
@@ -22,6 +22,7 @@ |
|
|
|
#include <functional> |
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
#include <unordered_set> |
|
|
|
#include <tuple> |
|
|
|
#include <unordered_map> |
|
|
|
#include <utility> |
|
|
|
@@ -39,43 +40,39 @@ namespace compile { |
|
|
|
ConvertCache g_ConvertCache; |
|
|
|
void ClearConvertCache() { g_ConvertCache.clear(); } |
|
|
|
|
|
|
|
namespace { |
|
|
|
// Return the list of nodes whose values are required beyond this segment. |
|
|
|
// Arguments: |
|
|
|
// lst: list of nodes (the segment) |
|
|
|
// nodes: list of nodes in the segment |
|
|
|
// users: dict mapping each node to its users (globally) |
|
|
|
// seen: set of nodes that are part of the segment |
|
|
|
AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) { |
|
|
|
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, |
|
|
|
const std::unordered_set<AnfNodePtr> &seen) { |
|
|
|
AnfNodePtrList output; |
|
|
|
if (users.size() == 0) { |
|
|
|
return output; |
|
|
|
} |
|
|
|
|
|
|
|
(void)std::transform( |
|
|
|
std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr { |
|
|
|
auto usersn = users.find(n); |
|
|
|
bool is_referred_out_of_segment = std::any_of( |
|
|
|
std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool { |
|
|
|
return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen); |
|
|
|
}); |
|
|
|
if (n->isa<CNode>() && is_referred_out_of_segment) { |
|
|
|
return n; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
}); |
|
|
|
|
|
|
|
// remove nullptr |
|
|
|
for (auto it = output.begin(); it != output.end();) { |
|
|
|
if (*it == nullptr) { |
|
|
|
it = output.erase(it); |
|
|
|
} else { |
|
|
|
++it; |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto iter = users.find(node); |
|
|
|
if (iter == users.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &node_users = iter->second; |
|
|
|
const bool has_outer_user = std::any_of( |
|
|
|
std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool { |
|
|
|
const bool is_outer_user = (seen.find(u.first) == seen.end()); |
|
|
|
return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2); |
|
|
|
}); |
|
|
|
if (has_outer_user) { |
|
|
|
output.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return output; |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr, |
|
|
|
AnfNodePtrToAnfNodePtrMap *eqv_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
@@ -129,6 +126,15 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr |
|
|
|
for (size_t i = 2; i < inps.size(); ++i) { |
|
|
|
args.emplace_back(NewValueNode(MakeValue(0))); |
|
|
|
} |
|
|
|
} else if (IsPrimitive(fn, prim::kPrimUpdateState)) { |
|
|
|
args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv)); |
|
|
|
args.emplace_back(RefSubGraphNode(fg, inps[2], &inputs, &eqv)); |
|
|
|
for (size_t i = 3; i < inps.size(); ++i) { |
|
|
|
auto &input = inps[i]; |
|
|
|
if (eqv.find(input) != eqv.end()) { |
|
|
|
args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv)); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), |
|
|
|
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); |
|
|
|
@@ -138,8 +144,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr |
|
|
|
eqv[n]->set_abstract(n->abstract()); |
|
|
|
eqv[n]->set_kernel_info(n->kernel_info_ptr()); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> eqv_keys; |
|
|
|
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), |
|
|
|
std::unordered_set<AnfNodePtr> eqv_keys; |
|
|
|
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()), |
|
|
|
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); |
|
|
|
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); |
|
|
|
AnfNodePtr fg_output; |
|
|
|
|