| @@ -16,6 +16,8 @@ | |||
| #include "backend/optimizer/pass/common_subexpression_elimination.h" | |||
| #include <memory> | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/flags.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -33,48 +35,60 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { | |||
| } | |||
| return false; | |||
| } | |||
| bool HasSideEffectAttr(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) { | |||
| return false; | |||
| } | |||
| return AnfAlgo::GetNodeAttr<bool>(cnode, GRAPH_FLAG_SIDE_EFFECT); | |||
| } | |||
| } // namespace | |||
| bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { | |||
| bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { | |||
| MS_EXCEPTION_IF_NULL(main); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| bool replace = false; | |||
| if (main->isa<ValueNode>() && node->isa<ValueNode>()) { | |||
| auto main_value = GetValueNode(main); | |||
| auto node_value = GetValueNode(node); | |||
| if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) { | |||
| replace = false; | |||
| return false; | |||
| } else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) { | |||
| replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); | |||
| return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); | |||
| } else { | |||
| replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); | |||
| return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); | |||
| } | |||
| } else if (main->isa<CNode>() && node->isa<CNode>()) { | |||
| if (check_side_effect && HasSideEffectAttr(main)) { | |||
| return false; | |||
| } | |||
| if (!CheckEqualKernelBuildInfo(main, node)) { | |||
| replace = false; | |||
| } else { | |||
| auto c_main = main->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_main); | |||
| auto c_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_node); | |||
| const auto &inp1 = c_main->inputs(); | |||
| const auto &inp2 = c_node->inputs(); | |||
| if (inp1.size() == inp2.size()) { | |||
| bool appsame = true; | |||
| for (size_t j = 0; j < inp1.size(); j++) { | |||
| MS_EXCEPTION_IF_NULL(inp1[j]); | |||
| MS_EXCEPTION_IF_NULL(inp2[j]); | |||
| if (!(*inp1[j] == *inp2[j])) { | |||
| appsame = false; | |||
| break; | |||
| } | |||
| } | |||
| replace = appsame; | |||
| return false; | |||
| } | |||
| auto c_main = main->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_main); | |||
| auto c_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_node); | |||
| const auto &inp1 = c_main->inputs(); | |||
| const auto &inp2 = c_node->inputs(); | |||
| if (inp1.size() != inp2.size()) { | |||
| return false; | |||
| } | |||
| for (size_t j = 0; j < inp1.size(); j++) { | |||
| auto inp1_j = inp1[j]; | |||
| auto inp2_j = inp2[j]; | |||
| MS_EXCEPTION_IF_NULL(inp1_j); | |||
| MS_EXCEPTION_IF_NULL(inp2_j); | |||
| if (!(*inp1_j == *inp2_j)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| return replace; | |||
| return false; | |||
| } | |||
| bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { | |||