From 25664a966c6ef52c3760ea12a020587dd236e1cf Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Tue, 30 Mar 2021 19:47:35 +0800 Subject: [PATCH] Return new UpdateState with input when match UpdateState(, MakeTuple(input,FuncGraph)). Replace US with its monad input in DeleteLoadUserUpdateState(). --- .../optimizer/auto_monad_eliminate.cc | 36 +++++++++++-------- .../optimizer/irpass/updatestate_eliminate.cc | 26 ++++++++++---- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc index be31d880e9..2424df355e 100644 --- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc @@ -110,11 +110,11 @@ std::vector> SplitGroup(const std::vector &topos // u3 = UpdateState(u2, b) //==> // delete the UpdateState -void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, - const AnfNodePtr &load) { - const auto &load_cnode = load->cast(); - const auto &u = load_cnode->input(2); - manager->Replace(load_user, u); +void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) { + const auto &update_state_cnode = load_user->cast(); + constexpr size_t monad_index = 1; + const auto &monad = update_state_cnode->input(monad_index); + manager->Replace(load_user, monad); } // Pattern2====================================== @@ -164,14 +164,17 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap manager->Replace(make_tuple, new_make_tuple); } -void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { +bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { + bool change = false; auto load_users = manager->node_users()[load]; for (const auto &load_user : load_users) { // Pattern1 if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { - DeleteLoadUserUpdateState(manager, load_user.first, load); + DeleteLoadUserUpdateState(manager, load_user.first); + change = true; continue; } + if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { const auto &make_tuple = load_user.first->cast(); auto &maketuple_users = manager->node_users()[make_tuple]; @@ -183,14 +186,17 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, // Pattern2 if (make_tuple->size() == 3) { DeleteLoadUserMakeTuple(manager, make_tuple, load); + change = true; continue; } // Pattern3 if (make_tuple->size() > 3) { ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); + change = true; } } } + return change; } bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, @@ -198,12 +204,13 @@ bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr if (group.size() <= 1) { return false; } + bool change = false; const auto &main = toposet[group[0]]; for (size_t i = 1; i < group.size(); i++) { - ReplaceLoadUser(manager, fg, toposet[group[i]]); + change = ReplaceLoadUser(manager, fg, toposet[group[i]]); manager->Replace(toposet[group[i]], main); } - return true; + return change; } AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { @@ -229,6 +236,7 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vectormanager(); mgr->SetEdge(load_node, second_input_index, monad); + change = true; } - return true; + return change; } // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => @@ -253,7 +262,7 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage std::vector toposet = TopoSort(fg->get_return()); std::vector need_replace_loads; std::vector> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); - const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); + bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); if (update_state_replaced) { changed = true; } @@ -264,13 +273,12 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end()); } for (auto &group : need_merge_loads) { - const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); - if (!changed && replaced) { + bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); + if (replaced) { changed = true; } } } - MS_LOG(DEBUG) << "changed: " << changed; return changed; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc index 400176d59c..d5e52f5691 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -201,7 +201,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN return new_update_state; } -// Return true if the function is only used by make_tuple. +// Return true if the func is only used by MakeTuple. bool OnlyMakeTupleUseFunc(const CNodePtr &make_tuple, const AnfNodePtr &func_node) { auto mgr = GetManager(make_tuple); if (mgr == nullptr) { @@ -222,18 +222,30 @@ bool OnlyMakeTupleUseFunc(const CNodePtr &make_tuple, const AnfNodePtr &func_nod // u2 = UpdateState(u1, t) // To: // t = make_tuple(input, Function) or t = make_tuple(Function, input) -// u2 = u1 +// u2 = UpdateState(u1, input) AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, const CNodePtr &make_tuple) { if (make_tuple->size() != kMakeTupleSize) { return nullptr; } + + // Get the other node that is not FuncGraph. + AnfNodePtr input_node = nullptr; auto &first_input = make_tuple->input(kInputIndex); - if (IsValueNode(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) { - return update_state->input(1); - } auto &second_input = make_tuple->input(kAttachIndex); - if (IsValueNode(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) { - return update_state->input(1); + if (IsValueNode(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) { + input_node = make_tuple->input(kAttachIndex); + } else if (IsValueNode(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) { + input_node = make_tuple->input(kInputIndex); + } + + // Create the new UpdateState node with `node_input`, replace the old UpdateStateNode. + if (input_node != nullptr) { + auto update_state_op = NewValueNode(prim::kPrimUpdateState); + auto fg = update_state->func_graph(); + auto new_update_state = fg->NewCNode({update_state_op, update_state->input(kInputIndex), input_node}); + new_update_state->set_abstract(update_state->abstract()); + new_update_state->set_scope(update_state->scope()); + return new_update_state; } return nullptr; }