Browse Source

Return new UpdateState with input when match UpdateState(, MakeTuple(input,FuncGraph)).

Replace US with its monad input in DeleteLoadUserUpdateState().
pull/14324/head
Zhang Qinghua 4 years ago
parent
commit
25664a966c
2 changed files with 41 additions and 21 deletions
  1. +22
    -14
      mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
  2. +19
    -7
      mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc

+ 22
- 14
mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc View File

@@ -110,11 +110,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
// u3 = UpdateState(u2, b)
//==>
// delete the UpdateState
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
const AnfNodePtr &load) {
const auto &load_cnode = load->cast<CNodePtr>();
const auto &u = load_cnode->input(2);
manager->Replace(load_user, u);
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
const auto &update_state_cnode = load_user->cast<CNodePtr>();
constexpr size_t monad_index = 1;
const auto &monad = update_state_cnode->input(monad_index);
manager->Replace(load_user, monad);
}

// Pattern2======================================
@@ -164,14 +164,17 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
manager->Replace(make_tuple, new_make_tuple);
}

void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
bool change = false;
auto load_users = manager->node_users()[load];
for (const auto &load_user : load_users) {
// Pattern1
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
DeleteLoadUserUpdateState(manager, load_user.first, load);
DeleteLoadUserUpdateState(manager, load_user.first);
change = true;
continue;
}

if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
const auto &make_tuple = load_user.first->cast<CNodePtr>();
auto &maketuple_users = manager->node_users()[make_tuple];
@@ -183,14 +186,17 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
// Pattern2
if (make_tuple->size() == 3) {
DeleteLoadUserMakeTuple(manager, make_tuple, load);
change = true;
continue;
}
// Pattern3
if (make_tuple->size() > 3) {
ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
change = true;
}
}
}
return change;
}

bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
@@ -198,12 +204,13 @@ bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr
if (group.size() <= 1) {
return false;
}
bool change = false;
const auto &main = toposet[group[0]];
for (size_t i = 1; i < group.size(); i++) {
ReplaceLoadUser(manager, fg, toposet[group[i]]);
change = ReplaceLoadUser(manager, fg, toposet[group[i]]);
manager->Replace(toposet[group[i]], main);
}
return true;
return change;
}

AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
@@ -229,6 +236,7 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
if (need_replace_loads.size() == 0) {
return false;
}
bool change = false;
constexpr size_t second_input_index = 2;
auto monad = GetFirstMonad(fg);
for (const auto &load_node : need_replace_loads) {
@@ -241,8 +249,9 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
}
auto mgr = fg->manager();
mgr->SetEdge(load_node, second_input_index, monad);
change = true;
}
return true;
return change;
}

// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
@@ -253,7 +262,7 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
std::vector<AnfNodePtr> need_replace_loads;
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
if (update_state_replaced) {
changed = true;
}
@@ -264,13 +273,12 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage
need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
}
for (auto &group : need_merge_loads) {
const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
if (!changed && replaced) {
bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
if (replaced) {
changed = true;
}
}
}
MS_LOG(DEBUG) << "changed: " << changed;
return changed;
}



+ 19
- 7
mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc View File

@@ -201,7 +201,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
return new_update_state;
}

// Return true if the function is only used by make_tuple.
// Return true if the func is only used by MakeTuple.
bool OnlyMakeTupleUseFunc(const CNodePtr &make_tuple, const AnfNodePtr &func_node) {
auto mgr = GetManager(make_tuple);
if (mgr == nullptr) {
@@ -222,18 +222,30 @@ bool OnlyMakeTupleUseFunc(const CNodePtr &make_tuple, const AnfNodePtr &func_nod
// u2 = UpdateState(u1, t)
// To:
// t = make_tuple(input, Function) or t = make_tuple(Function, input)
// u2 = u1
// u2 = UpdateState(u1, input)
AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, const CNodePtr &make_tuple) {
if (make_tuple->size() != kMakeTupleSize) {
return nullptr;
}

// Get the other node that is not FuncGraph.
AnfNodePtr input_node = nullptr;
auto &first_input = make_tuple->input(kInputIndex);
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
return update_state->input(1);
}
auto &second_input = make_tuple->input(kAttachIndex);
if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
return update_state->input(1);
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
input_node = make_tuple->input(kAttachIndex);
} else if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
input_node = make_tuple->input(kInputIndex);
}

// Create the new UpdateState node with `node_input`, replace the old UpdateStateNode.
if (input_node != nullptr) {
auto update_state_op = NewValueNode(prim::kPrimUpdateState);
auto fg = update_state->func_graph();
auto new_update_state = fg->NewCNode({update_state_op, update_state->input(kInputIndex), input_node});
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
return new_update_state;
}
return nullptr;
}


Loading…
Cancel
Save