diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index e8e2295f6e..cb74a3e645 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -140,7 +140,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vectorcast(); - if (node == nullptr || fused_ops->count(node)) { + if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) { continue; } bool is_fusible_op = IsFusibleOp(node); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 3a6b949f80..eb50aa8884 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -147,7 +147,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(mng); for (const auto &n : todos) { auto node = n->cast(); - if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) { + if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || + !CanExpand(node)) { continue; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index f2fa804431..b6e7e06ed2 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN // Set attr secondly. AnfAlgo::SetNodeAttr(key, value, node); } + +bool IsKeepBasicNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + + static std::vector contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index", + "aggregate", "aggregate_input_indexx"}; + static std::vector> attrs_with_value = { + [](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }}; + + // If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is. + // If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true. + if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(), + [&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) || + std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(), + [&cnode](std::function func) -> bool { return func(cnode); })) { + return true; + } + + return false; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index b36460971c..572dcabc44 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -89,6 +89,7 @@ std::vector GetReduceAxis(const AnfNodePtr &node); CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); +bool IsKeepBasicNode(const AnfNodePtr &node); template ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {