|
|
|
@@ -36,6 +36,7 @@ |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
#include "ir/func_graph_cloner.h" |
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" |
|
|
|
#include "backend/optimizer/pass/getitem_tuple.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
@@ -60,20 +61,20 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool IsFuse(const AnfNodePtr &node) { |
|
|
|
// composite fuse composite op |
|
|
|
if (AnfAlgo::IsGraphKernel(node)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return IsBasicFuseOp(node); |
|
|
|
} |
|
|
|
|
|
|
|
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { |
|
|
|
if (cur_node == node) { |
|
|
|
return FOLLOW; |
|
|
|
} |
|
|
|
bool is_fusable = IsFuse(node); |
|
|
|
return is_fusable ? FOLLOW : EXCLUDE; |
|
|
|
if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) { |
|
|
|
return FOLLOW; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { |
|
|
|
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
if (AnfAlgo::IsGraphKernel(prev_node)) { |
|
|
|
return FOLLOW; |
|
|
|
} |
|
|
|
} |
|
|
|
return EXCLUDE; |
|
|
|
} |
|
|
|
|
|
|
|
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { |
|
|
|
@@ -185,6 +186,60 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
// The GetItem node should be fused with its real input and users. |
|
|
|
// If its real input is not in the fuse_list, the GetItem should be excluded. |
|
|
|
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { |
|
|
|
if (fused_op.empty()) return AnfNodePtrList(); |
|
|
|
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); |
|
|
|
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; |
|
|
|
|
|
|
|
auto mng = fused_op[0]->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
bool changed = true; |
|
|
|
while (changed) { |
|
|
|
changed = false; |
|
|
|
AnfNodePtrList remove_list; |
|
|
|
for (auto node : fused_op_set) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue; |
|
|
|
// GetItem should be fused with its real input. |
|
|
|
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
if (check_include(prev_node) == EXCLUDE) { |
|
|
|
remove_list.push_back(node); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
// GetItem should be fused with its all users. |
|
|
|
auto &users = mng->node_users()[node]; |
|
|
|
bool outside_user_found = false; |
|
|
|
for (auto iter = users.begin(); iter != users.end(); ++iter) { |
|
|
|
if (check_include(iter->first) == EXCLUDE) { |
|
|
|
outside_user_found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (outside_user_found) { |
|
|
|
remove_list = DeepUsersSearch(node, check_include, mng); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!remove_list.empty()) { |
|
|
|
for (auto node : remove_list) { |
|
|
|
fused_op_set.erase(node); |
|
|
|
} |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// keep the original order of fused_op. |
|
|
|
AnfNodePtrList result; |
|
|
|
for (auto node : fused_op) { |
|
|
|
if (fused_op_set.count(node)) { |
|
|
|
result.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) { |
|
|
|
if (lst->size() < 2) { |
|
|
|
return; |
|
|
|
@@ -254,6 +309,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) { |
|
|
|
if (used_nodes.size() > 1) { |
|
|
|
used_nodes = RemoveCircle(used_nodes); |
|
|
|
} |
|
|
|
used_nodes = RemoveWildGetitem(used_nodes); |
|
|
|
TopoSortForNodeList(&used_nodes); |
|
|
|
return used_nodes; |
|
|
|
} |
|
|
|
@@ -288,8 +344,22 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
void EliminateGetItem(const FuncGraphPtr &func_graph) { |
|
|
|
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); |
|
|
|
auto todos = TopoSort(func_graph->get_return()); |
|
|
|
for (auto node : todos) { |
|
|
|
if (AnfAlgo::IsGraphKernel(node)) { |
|
|
|
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) { |
|
|
|
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph)); |
|
|
|
auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph)); |
|
|
|
if (changed) { |
|
|
|
EliminateGetItem(func_graph); |
|
|
|
} |
|
|
|
return changed; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |