|
|
|
@@ -217,48 +217,50 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); |
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, |
|
|
|
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); |
|
|
|
void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads); |
|
|
|
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads); |
|
|
|
|
|
|
|
// Search consecutive load nodes from UpdateState node. |
|
|
|
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads) { |
|
|
|
void GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads) { |
|
|
|
auto &attach = update_state->input(kAttachIndex); |
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { |
|
|
|
update_states->emplace_back(update_state); |
|
|
|
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads); |
|
|
|
GetLoadsFollowLoad(update_state, attach->cast<CNodePtr>(), update_states, loads); |
|
|
|
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { |
|
|
|
GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads); |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { |
|
|
|
update_states->emplace_back(update_state); |
|
|
|
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads); |
|
|
|
} |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { |
|
|
|
void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads) { |
|
|
|
update_states->emplace_back(update_state); |
|
|
|
loads->emplace_back(load); |
|
|
|
auto &load_attach = load->input(kAttachIndex); |
|
|
|
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) { |
|
|
|
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1; |
|
|
|
GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads); |
|
|
|
} |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, |
|
|
|
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { |
|
|
|
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states, |
|
|
|
std::vector<CNodePtr> *loads) { |
|
|
|
if (!OnlyUpdateStateUse(update_state, make_tuple)) { |
|
|
|
// UpdateState should be the only user of |
|
|
|
return 0; |
|
|
|
// UpdateState should be the only user of make_tuple. |
|
|
|
return; |
|
|
|
} |
|
|
|
auto &inputs = make_tuple->inputs(); |
|
|
|
bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), |
|
|
|
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); }); |
|
|
|
const auto &monad = update_state->input(kInputIndex); |
|
|
|
bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), [&monad](const AnfNodePtr &node) { |
|
|
|
// Tuple element should be Load and use same monad that UpdateState used. |
|
|
|
return (IsPrimitiveCNode(node, prim::kPrimLoad) && node->cast<CNodePtr>()->input(kAttachIndex) == monad); |
|
|
|
}); |
|
|
|
if (!is_all_load) { |
|
|
|
// Stop if not all tuple elements are load node. |
|
|
|
return 0; |
|
|
|
// Stop if not all tuple elements are load nodes and use same monad. |
|
|
|
return; |
|
|
|
} |
|
|
|
// Add load nodes from tuple elements. |
|
|
|
// Add update_state and load nodes. |
|
|
|
update_states->emplace_back(update_state); |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
auto &element = inputs.at(i); |
|
|
|
loads->emplace_back(element->cast<CNodePtr>()); |
|
|
|
@@ -266,9 +268,8 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu |
|
|
|
// Follow prev update state if found. |
|
|
|
auto prev_node = update_state->input(kInputIndex); |
|
|
|
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) { |
|
|
|
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1; |
|
|
|
GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads); |
|
|
|
} |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
// Create a MakeTuple node before UpdateState for same nodes, if there are more than 1 node used. |
|
|
|
@@ -549,63 +550,64 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto &attach = update_state_node->input(kAttachIndex); |
|
|
|
|
|
|
|
// Handle UpdateState(u, Depend(...)). |
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { |
|
|
|
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
|
|
|
|
// Handle UpdateState(u, Partial(...)). |
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimPartial)) { |
|
|
|
return EliminateUpdateStateOnlyUsedNode(update_state_node, attach); |
|
|
|
} |
|
|
|
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad); |
|
|
|
if (attach_is_load) { |
|
|
|
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
// We should continue check when useless Load not found, |
|
|
|
// since GetLoadsFromUpdateState() also need to check Load. |
|
|
|
} |
|
|
|
|
|
|
|
const bool attach_is_assign = IsPrimitiveCNode(attach, prim::kPrimAssign); |
|
|
|
if (attach_is_assign) { |
|
|
|
// Handle UpdateState(u, Assign(...)). |
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) { |
|
|
|
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
new_node = EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach); |
|
|
|
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach); |
|
|
|
} |
|
|
|
|
|
|
|
// Handle UpdateState(u, Load(...)). |
|
|
|
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad); |
|
|
|
if (attach_is_load) { |
|
|
|
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Handle UpdateState(u, MakeTuple(...)). |
|
|
|
const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple); |
|
|
|
if (attach_is_tuple) { |
|
|
|
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, attach->cast<CNodePtr>()); |
|
|
|
auto make_tuple = attach->cast<CNodePtr>(); |
|
|
|
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
// We should continue check when MakeTuple with "Dead Node" not found, |
|
|
|
// since GetLoadsFromUpdateState() also need to check MakeTuple. |
|
|
|
|
|
|
|
new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, attach->cast<CNodePtr>()); |
|
|
|
new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, make_tuple); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
|
|
|
|
new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, attach->cast<CNodePtr>()); |
|
|
|
new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<CNodePtr> update_states; |
|
|
|
std::vector<CNodePtr> loads; |
|
|
|
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) { |
|
|
|
return EliminateUpdateStateForLoads(update_state_node, update_states, loads); |
|
|
|
// Merge UpdateStates for Loads. |
|
|
|
if (attach_is_load || attach_is_tuple) { |
|
|
|
std::vector<CNodePtr> update_states; |
|
|
|
std::vector<CNodePtr> loads; |
|
|
|
GetLoadsFromUpdateState(update_state_node, &update_states, &loads); |
|
|
|
if (update_states.size() > 1 && loads.size() > 1) { |
|
|
|
return EliminateUpdateStateForLoads(update_state_node, update_states, loads); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// Eliminate UpdateStates that attaches a no-side-effect node. |
|
|
|
if (!attach_is_load && !attach_is_tuple) { |
|
|
|
return EliminateUpdateStateForPureNode(update_state_node, attach); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
return EliminateUpdateStateForPureNode(update_state_node, attach); |
|
|
|
} |
|
|
|
|
|
|
|
// Eliminate Monad parameter for switch call. |
|
|
|
|