|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -40,7 +40,8 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std: |
|
|
|
} |
|
|
|
auto main_attrs = main_primitive->attrs(); |
|
|
|
auto node_attrs = node_primitive->attrs(); |
|
|
|
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"}; |
|
|
|
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format", "input_names", |
|
|
|
"output_names"}; |
|
|
|
for (auto &attr : exclude_attrs) { |
|
|
|
auto main_attrs_iter = main_attrs.find(attr); |
|
|
|
if (main_attrs_iter != main_attrs.end()) { |
|
|
|
@@ -121,6 +122,14 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const |
|
|
|
bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_); |
|
|
|
return graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); |
|
|
|
auto changed = graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); |
|
|
|
auto nodes = TopoSort(func_graph->get_return()); |
|
|
|
for (auto node : nodes) { |
|
|
|
auto graph_kernel_fg = GetCNodeFuncGraph(node); |
|
|
|
if (graph_kernel_fg != nullptr) { |
|
|
|
changed = graphkernel_backend_cse->Cse(graph_kernel_fg, graph_kernel_fg->manager()) || changed; |
|
|
|
} |
|
|
|
} |
|
|
|
return changed; |
|
|
|
} |
|
|
|
} // namespace mindspore::graphkernel |