diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc index 2e8cfc283b..a134c93346 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,8 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std: } auto main_attrs = main_primitive->attrs(); auto node_attrs = node_primitive->attrs(); - std::vector exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"}; + std::vector exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format", "input_names", + "output_names"}; for (auto &attr : exclude_attrs) { auto main_attrs_iter = main_attrs.find(attr); if (main_attrs_iter != main_attrs.end()) { @@ -121,6 +122,14 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto graphkernel_backend_cse = std::make_shared(black_list_); - return graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); + auto changed = graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); + auto nodes = TopoSort(func_graph->get_return()); + for (auto node : nodes) { + auto graph_kernel_fg = GetCNodeFuncGraph(node); + if (graph_kernel_fg != nullptr) { + changed = graphkernel_backend_cse->Cse(graph_kernel_fg, graph_kernel_fg->manager()) || changed; + } + } + return changed; } } // namespace mindspore::graphkernel