|
|
|
@@ -110,11 +110,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos |
|
|
|
// u3 = UpdateState(u2, b) |
|
|
|
//==> |
|
|
|
// delete the UpdateState |
|
|
|
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, |
|
|
|
const AnfNodePtr &load) { |
|
|
|
const auto &load_cnode = load->cast<CNodePtr>(); |
|
|
|
const auto &u = load_cnode->input(2); |
|
|
|
manager->Replace(load_user, u); |
|
|
|
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) { |
|
|
|
const auto &update_state_cnode = load_user->cast<CNodePtr>(); |
|
|
|
constexpr size_t monad_index = 1; |
|
|
|
const auto &monad = update_state_cnode->input(monad_index); |
|
|
|
manager->Replace(load_user, monad); |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern2====================================== |
|
|
|
@@ -164,14 +164,17 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap |
|
|
|
manager->Replace(make_tuple, new_make_tuple); |
|
|
|
} |
|
|
|
|
|
|
|
void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { |
|
|
|
bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { |
|
|
|
bool change = false; |
|
|
|
auto load_users = manager->node_users()[load]; |
|
|
|
for (const auto &load_user : load_users) { |
|
|
|
// Pattern1 |
|
|
|
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { |
|
|
|
DeleteLoadUserUpdateState(manager, load_user.first, load); |
|
|
|
DeleteLoadUserUpdateState(manager, load_user.first); |
|
|
|
change = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { |
|
|
|
const auto &make_tuple = load_user.first->cast<CNodePtr>(); |
|
|
|
auto &maketuple_users = manager->node_users()[make_tuple]; |
|
|
|
@@ -183,14 +186,17 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, |
|
|
|
// Pattern2 |
|
|
|
if (make_tuple->size() == 3) { |
|
|
|
DeleteLoadUserMakeTuple(manager, make_tuple, load); |
|
|
|
change = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Pattern3 |
|
|
|
if (make_tuple->size() > 3) { |
|
|
|
ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); |
|
|
|
change = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return change; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, |
|
|
|
@@ -198,12 +204,13 @@ bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr |
|
|
|
if (group.size() <= 1) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
bool change = false; |
|
|
|
const auto &main = toposet[group[0]]; |
|
|
|
for (size_t i = 1; i < group.size(); i++) { |
|
|
|
ReplaceLoadUser(manager, fg, toposet[group[i]]); |
|
|
|
change = ReplaceLoadUser(manager, fg, toposet[group[i]]); |
|
|
|
manager->Replace(toposet[group[i]], main); |
|
|
|
} |
|
|
|
return true; |
|
|
|
return change; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { |
|
|
|
@@ -229,6 +236,7 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode |
|
|
|
if (need_replace_loads.size() == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
bool change = false; |
|
|
|
constexpr size_t second_input_index = 2; |
|
|
|
auto monad = GetFirstMonad(fg); |
|
|
|
for (const auto &load_node : need_replace_loads) { |
|
|
|
@@ -241,8 +249,9 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode |
|
|
|
} |
|
|
|
auto mgr = fg->manager(); |
|
|
|
mgr->SetEdge(load_node, second_input_index, monad); |
|
|
|
change = true; |
|
|
|
} |
|
|
|
return true; |
|
|
|
return change; |
|
|
|
} |
|
|
|
|
|
|
|
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => |
|
|
|
@@ -253,7 +262,7 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage |
|
|
|
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); |
|
|
|
std::vector<AnfNodePtr> need_replace_loads; |
|
|
|
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); |
|
|
|
const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); |
|
|
|
bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); |
|
|
|
if (update_state_replaced) { |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
@@ -264,13 +273,12 @@ bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manage |
|
|
|
need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end()); |
|
|
|
} |
|
|
|
for (auto &group : need_merge_loads) { |
|
|
|
const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); |
|
|
|
if (!changed && replaced) { |
|
|
|
bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group); |
|
|
|
if (replaced) { |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "changed: " << changed; |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
|