Browse Source

[auto-monad] Fix a bug in updatestate_eliminate

All Load nodes in a MakeTuple should use same monad if we are going to merge UpdateStates for them.
tags/v1.2.0-rc1
He Wei 4 years ago
parent
commit
09e7733861
1 changed files with 55 additions and 53 deletions
  1. +55
    -53
      mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc

+ 55
- 53
mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc View File

@@ -217,48 +217,50 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
return nullptr;
}

size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads);
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads);

// Search consecutive load nodes from UpdateState node.
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
void GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
auto &attach = update_state->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
update_states->emplace_back(update_state);
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
GetLoadsFollowLoad(update_state, attach->cast<CNodePtr>(), update_states, loads);
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
}
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
update_states->emplace_back(update_state);
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
}
return 0;
}

size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
update_states->emplace_back(update_state);
loads->emplace_back(load);
auto &load_attach = load->input(kAttachIndex);
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads);
}
return 1;
}

size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
if (!OnlyUpdateStateUse(update_state, make_tuple)) {
// UpdateState should be the only user of
return 0;
// UpdateState should be the only user of make_tuple.
return;
}
auto &inputs = make_tuple->inputs();
bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(),
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
const auto &monad = update_state->input(kInputIndex);
bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), [&monad](const AnfNodePtr &node) {
// Tuple element should be Load and use same monad that UpdateState used.
return (IsPrimitiveCNode(node, prim::kPrimLoad) && node->cast<CNodePtr>()->input(kAttachIndex) == monad);
});
if (!is_all_load) {
// Stop if not all tuple elements are load node.
return 0;
// Stop if not all tuple elements are load nodes and use same monad.
return;
}
// Add load nodes from tuple elements.
// Add update_state and load nodes.
update_states->emplace_back(update_state);
for (size_t i = 1; i < inputs.size(); ++i) {
auto &element = inputs.at(i);
loads->emplace_back(element->cast<CNodePtr>());
@@ -266,9 +268,8 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
// Follow prev update state if found.
auto prev_node = update_state->input(kInputIndex);
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads);
}
return 1;
}

// Create a MakeTuple node before UpdateState for same nodes, if there are more than 1 node used.
@@ -549,63 +550,64 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
return nullptr;
}
auto &attach = update_state_node->input(kAttachIndex);

// Handle UpdateState(u, Depend(...)).
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
}

// Handle UpdateState(u, Partial(...)).
if (IsPrimitiveCNode(attach, prim::kPrimPartial)) {
return EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
}
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad);
if (attach_is_load) {
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
if (new_node != nullptr) {
return new_node;
}
// We should continue check when useless Load not found,
// since GetLoadsFromUpdateState() also need to check Load.
}

const bool attach_is_assign = IsPrimitiveCNode(attach, prim::kPrimAssign);
if (attach_is_assign) {
// Handle UpdateState(u, Assign(...)).
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach);
if (new_node != nullptr) {
return new_node;
}
new_node = EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
}

// Handle UpdateState(u, Load(...)).
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad);
if (attach_is_load) {
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
if (new_node != nullptr) {
return new_node;
}
}

// Handle UpdateState(u, MakeTuple(...)).
const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple);
if (attach_is_tuple) {
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, attach->cast<CNodePtr>());
auto make_tuple = attach->cast<CNodePtr>();
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple);
if (new_node != nullptr) {
return new_node;
}
// We should continue check when MakeTuple with "Dead Node" not found,
// since GetLoadsFromUpdateState() also need to check MakeTuple.

new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, attach->cast<CNodePtr>());
new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, make_tuple);
if (new_node != nullptr) {
return new_node;
}

new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, attach->cast<CNodePtr>());
new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
if (new_node != nullptr) {
return new_node;
}
}
std::vector<CNodePtr> update_states;
std::vector<CNodePtr> loads;
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
// Merge UpdateStates for Loads.
if (attach_is_load || attach_is_tuple) {
std::vector<CNodePtr> update_states;
std::vector<CNodePtr> loads;
GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
if (update_states.size() > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
}
return nullptr;
}
// Eliminate UpdateStates that attaches a no-side-effect node.
if (!attach_is_load && !attach_is_tuple) {
return EliminateUpdateStateForPureNode(update_state_node, attach);
}
return nullptr;
return EliminateUpdateStateForPureNode(update_state_node, attach);
}

// Eliminate Monad parameter for switch call.


Loading…
Cancel
Save