diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc index fc6c96cb8b..e16143b968 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -217,48 +217,50 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c return nullptr; } -size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector *update_states, std::vector *loads); -size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, - std::vector *update_states, std::vector *loads); +void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector *update_states, + std::vector *loads); +void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector *update_states, + std::vector *loads); // Search consecutive load nodes from UpdateState node. -size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector *update_states, - std::vector *loads) { +void GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector *update_states, + std::vector *loads) { auto &attach = update_state->input(kAttachIndex); if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { - update_states->emplace_back(update_state); - return GetLoadsFollowLoad(attach->cast(), update_states, loads); + GetLoadsFollowLoad(update_state, attach->cast(), update_states, loads); + } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { + GetLoadsFollowTuple(update_state, attach->cast(), update_states, loads); } - if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { - update_states->emplace_back(update_state); - return GetLoadsFollowTuple(update_state, attach->cast(), update_states, loads); - } - return 0; } -size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector *update_states, std::vector *loads) { +void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector *update_states, + std::vector *loads) { + update_states->emplace_back(update_state); loads->emplace_back(load); auto &load_attach = load->input(kAttachIndex); if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) { - return GetLoadsFromUpdateState(load_attach->cast(), update_states, loads) + 1; + GetLoadsFromUpdateState(load_attach->cast(), update_states, loads); } - return 1; } -size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, - std::vector *update_states, std::vector *loads) { +void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector *update_states, + std::vector *loads) { if (!OnlyUpdateStateUse(update_state, make_tuple)) { - // UpdateState should be the only user of - return 0; + // UpdateState should be the only user of make_tuple. + return; } auto &inputs = make_tuple->inputs(); - bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), - [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); }); + const auto &monad = update_state->input(kInputIndex); + bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), [&monad](const AnfNodePtr &node) { + // Tuple element should be Load and use same monad that UpdateState used. + return (IsPrimitiveCNode(node, prim::kPrimLoad) && node->cast()->input(kAttachIndex) == monad); + }); if (!is_all_load) { - // Stop if not all tuple elements are load node. - return 0; + // Stop if not all tuple elements are load nodes and use same monad. + return; } - // Add load nodes from tuple elements. + // Add update_state and load nodes. + update_states->emplace_back(update_state); for (size_t i = 1; i < inputs.size(); ++i) { auto &element = inputs.at(i); loads->emplace_back(element->cast()); @@ -266,9 +268,8 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu // Follow prev update state if found. auto prev_node = update_state->input(kInputIndex); if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) { - return GetLoadsFromUpdateState(prev_node->cast(), update_states, loads) + 1; + GetLoadsFromUpdateState(prev_node->cast(), update_states, loads); } - return 1; } // Create a MakeTuple node before UpdateState for same nodes, if there are more than 1 node used. @@ -549,63 +550,64 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode return nullptr; } auto &attach = update_state_node->input(kAttachIndex); + + // Handle UpdateState(u, Depend(...)). if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { return EliminateUpdateStateWithDepend(update_state_node, attach->cast()); } + + // Handle UpdateState(u, Partial(...)). if (IsPrimitiveCNode(attach, prim::kPrimPartial)) { return EliminateUpdateStateOnlyUsedNode(update_state_node, attach); } - const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad); - if (attach_is_load) { - auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach); - if (new_node != nullptr) { - return new_node; - } - // We should continue check when useless Load not found, - // since GetLoadsFromUpdateState() also need to check Load. - } - const bool attach_is_assign = IsPrimitiveCNode(attach, prim::kPrimAssign); - if (attach_is_assign) { + // Handle UpdateState(u, Assign(...)). + if (IsPrimitiveCNode(attach, prim::kPrimAssign)) { auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach); if (new_node != nullptr) { return new_node; } - new_node = EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach); + return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach); + } + + // Handle UpdateState(u, Load(...)). + const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad); + if (attach_is_load) { + auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach); if (new_node != nullptr) { return new_node; } } + // Handle UpdateState(u, MakeTuple(...)). const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple); if (attach_is_tuple) { - auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, attach->cast()); + auto make_tuple = attach->cast(); + auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple); if (new_node != nullptr) { return new_node; } - // We should continue check when MakeTuple with "Dead Node" not found, - // since GetLoadsFromUpdateState() also need to check MakeTuple. - - new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, attach->cast()); + new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, make_tuple); if (new_node != nullptr) { return new_node; } - - new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, attach->cast()); + new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple); if (new_node != nullptr) { return new_node; } } - std::vector update_states; - std::vector loads; - if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) { - return EliminateUpdateStateForLoads(update_state_node, update_states, loads); + // Merge UpdateStates for Loads. + if (attach_is_load || attach_is_tuple) { + std::vector update_states; + std::vector loads; + GetLoadsFromUpdateState(update_state_node, &update_states, &loads); + if (update_states.size() > 1 && loads.size() > 1) { + return EliminateUpdateStateForLoads(update_state_node, update_states, loads); + } + return nullptr; } // Eliminate UpdateStates that attaches a no-side-effect node. - if (!attach_is_load && !attach_is_tuple) { - return EliminateUpdateStateForPureNode(update_state_node, attach); - } - return nullptr; + return EliminateUpdateStateForPureNode(update_state_node, attach); } // Eliminate Monad parameter for switch call.