|
|
|
@@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode |
|
|
|
return EXCLUDE; |
|
|
|
} |
|
|
|
|
|
|
|
// The GetItem node should be fused with its real input and users. |
|
|
|
// The GetItem node should be fused with its real input. |
|
|
|
// If its real input is not in the fuse_list, the GetItem should be excluded. |
|
|
|
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { |
|
|
|
if (fused_op.empty()) return AnfNodePtrList(); |
|
|
|
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); |
|
|
|
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; |
|
|
|
|
|
|
|
auto mng = fused_op[0]->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
bool changed = true; |
|
|
|
while (changed) { |
|
|
|
changed = false; |
|
|
|
AnfNodePtrList remove_list; |
|
|
|
for (auto getitem : fused_op_set) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; |
|
|
|
|
|
|
|
// GetItem should be fused with its real input. |
|
|
|
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
if (check_include(prev_node) == EXCLUDE) { |
|
|
|
remove_list.push_back(getitem); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
// GetItem should be fused with its all users. |
|
|
|
const auto &users = mng->node_users()[getitem]; |
|
|
|
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) { |
|
|
|
return check_include(user.first) == EXCLUDE; |
|
|
|
})) { |
|
|
|
remove_list = DeepLinkedGraphSearch(getitem, check_include); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!remove_list.empty()) { |
|
|
|
|