|
|
|
@@ -107,20 +107,6 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos |
|
|
|
// a = Load(para1, u1) |
|
|
|
// ... |
|
|
|
// b = Load(para1, u2) |
|
|
|
// u3 = UpdateState(u2, b) |
|
|
|
//==> |
|
|
|
// delete the UpdateState |
|
|
|
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, |
|
|
|
const AnfNodePtr &load) { |
|
|
|
const auto &load_cnode = load->cast<CNodePtr>(); |
|
|
|
const auto &u = load_cnode->input(2); |
|
|
|
manager->Replace(load_user, u); |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern2====================================== |
|
|
|
// a = Load(para1, u1) |
|
|
|
// ... |
|
|
|
// b = Load(para1, u2) |
|
|
|
// t = make_tuple(x, b) |
|
|
|
// u3 = UpdateState(u2, t) |
|
|
|
//==> |
|
|
|
@@ -141,7 +127,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr |
|
|
|
manager->Replace(make_tuple, other_input); |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern3====================================== |
|
|
|
// Pattern2====================================== |
|
|
|
// a = Load(para1, u1) |
|
|
|
// ... |
|
|
|
// b = Load(para1, u2) |
|
|
|
@@ -167,11 +153,6 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap |
|
|
|
void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { |
|
|
|
auto load_users = manager->node_users()[load]; |
|
|
|
for (const auto &load_user : load_users) { |
|
|
|
// Pattern1 |
|
|
|
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { |
|
|
|
DeleteLoadUserUpdateState(manager, load_user.first, load); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { |
|
|
|
const auto &make_tuple = load_user.first->cast<CNodePtr>(); |
|
|
|
auto &maketuple_users = manager->node_users()[make_tuple]; |
|
|
|
@@ -180,12 +161,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, |
|
|
|
if (!maketuple_as_input_of_update) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Pattern2 |
|
|
|
// Pattern1 |
|
|
|
if (make_tuple->size() == 3) { |
|
|
|
DeleteLoadUserMakeTuple(manager, make_tuple, load); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Pattern3 |
|
|
|
// Pattern2 |
|
|
|
if (make_tuple->size() > 3) { |
|
|
|
ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); |
|
|
|
} |
|
|
|
|