|
|
|
@@ -291,6 +291,50 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
// Remove all nodes related to UpdateStates, if they're redundant. |
|
|
|
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) { |
|
|
|
if (update_states.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto mgr = GetManager(update_states[0]); |
|
|
|
|
|
|
|
// 1. Remove the use of UpdateState nodes, except the last one. |
|
|
|
for (auto i = update_states.size() - 1; i > 0; i--) { |
|
|
|
auto &us = update_states[i]; |
|
|
|
mgr->Replace(us, us->input(kInputIndex)); |
|
|
|
} |
|
|
|
|
|
|
|
// 2. Remove the Depend users of last UpdateState node. |
|
|
|
auto &node_users = mgr->node_users(); |
|
|
|
auto iter = node_users.find(update_states[0]); |
|
|
|
if (iter == node_users.end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto &us_users = iter->second; |
|
|
|
if (us_users.size() < 2) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> depend_nodes; |
|
|
|
for (auto &user : us_users) { |
|
|
|
if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) { |
|
|
|
depend_nodes.emplace_back(user.first); |
|
|
|
} |
|
|
|
} |
|
|
|
if (depend_nodes.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
ssize_t end = 0; |
|
|
|
// If all users are Depend CNode. |
|
|
|
if (depend_nodes.size() == us_users.size()) { |
|
|
|
end = 1; |
|
|
|
} |
|
|
|
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) { |
|
|
|
const auto &depend_node = depend_nodes[i]; |
|
|
|
const auto &depend_cnode = depend_node->cast<CNodePtr>(); |
|
|
|
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Eliminate UpdateStates for consecutive Loads. |
|
|
|
// Convert: |
|
|
|
// x1 = Load(x1, u) |
|
|
|
@@ -336,10 +380,9 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const |
|
|
|
mgr->SetEdge(load, kAttachIndex, input_monad); |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto i = update_states.size() - 1; i > 0; i--) { |
|
|
|
auto &us = update_states[i]; |
|
|
|
mgr->Replace(us, us->input(kInputIndex)); |
|
|
|
} |
|
|
|
|
|
|
|
EliminateUselessNodesForUpdateStates(update_states); |
|
|
|
|
|
|
|
if (make_tuple_inputs.size() == 1) { |
|
|
|
// This should not happen. |
|
|
|
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); |
|
|
|
|