From 4d6ff51ca17f46c21db802af7f58ff1112593d7c Mon Sep 17 00:00:00 2001 From: gongxiaoqing Date: Thu, 25 Feb 2021 22:28:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!12569?= =?UTF-8?q?=20:=20Add=20circle=20check=20in=20ub=20fusion'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ascend/buffer_fusion/ub_pattern_fusion.cc | 22 ------------------- .../ccsrc/backend/optimizer/common/helper.cc | 21 ++++++++++++++++++ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index 111a937468..6457098627 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -27,7 +27,6 @@ #include "base/core_ops.h" #include "runtime/device/kernel_info.h" #include "utils/ms_context.h" -#include "backend/optimizer/common/helper.h" namespace mindspore { namespace opt { @@ -354,24 +353,6 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) { - bool has_circle = false; - for (auto &inp : fusion_info.inputs_list) { - if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { - has_circle = true; - break; - } - } - - if (has_circle) { - buffer_fusion_infos->erase(fusion_id); - } - } -} } // namespace void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, @@ -380,9 +361,6 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); - // Remove circle which will produce a circle if do fusion - RemoveCircle(*kernel_graph, buffer_fusion_infos); - for (auto &buffer_fusion_info : *buffer_fusion_infos) { buffer_fusion_info.second.kernel_build_info = CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 5414240d48..5c95bd2b76 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -49,6 +49,23 @@ std::vector Convert2Long(const std::vector &v) { bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector &nodes) { MS_EXCEPTION_IF_NULL(node); + std::vector node_list = TopoSort(graph.get_return()); + std::map> control_depend_map; + for (auto &nd : node_list) { + MS_EXCEPTION_IF_NULL(nd); + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + auto control_depend = nd->cast(); + auto prior_node = control_depend->input(kControlDependPriorIndex); + auto behind_node = control_depend->input(kControlDependBehindIndex); + auto it = control_depend_map.find(behind_node); + if (it == control_depend_map.end()) { + control_depend_map[behind_node] = std::set{prior_node}; + } else { + it->second.insert(prior_node); + } + } + } + FuncGraphManagerPtr manager = graph.manager(); MS_EXCEPTION_IF_NULL(manager); @@ -71,6 +88,10 @@ bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector< auto inputs = cnode->inputs(); (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); } + auto it = control_depend_map.find(nd); + if (it != control_depend_map.end()) { + (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); + } } return false; }