diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 9188231a26..92f9fea933 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -51,6 +51,7 @@ #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" @@ -100,6 +101,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc index caea9599c1..78e6856d5a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -73,13 +73,16 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An return mul0; } -bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) { - MS_EXCEPTION_IF_NULL(graph); +bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, + const AnfNodePtr &reduce_sum) { MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(mul1_anf); MS_EXCEPTION_IF_NULL(reduce_sum); - if (!mul0_anf->isa()) { + if (!mul0_anf->isa() || !mul1_anf->isa()) { return true; } + auto mul1 = mul1_anf->cast(); + MS_EXCEPTION_IF_NULL(mul1); auto mul0 = mul0_anf->cast(); MS_EXCEPTION_IF_NULL(mul0); @@ -88,20 +91,14 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf return true; } - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(reduce_sum) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; + if (IsDepend(graph, mul0->input(1), reduce_sum)) { + MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; + return true; } - const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum]; - auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair &node_index) { - return node_index.first == mul0->input(1) || node_index.first == mul0; - }); - if (it != outputs_set.end()) { - MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"; + if (IsDepend(graph, mul1->input(1), mul0)) { + MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; return true; } - return false; } } // namespace @@ -131,7 +128,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; return nullptr; } - if (QuitFusion(graph, mul0, node)) { + if (QuitFusion(graph, mul0, mul1, node)) { return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index 5cc3374ea5..dfb6a32dde 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -18,6 +18,9 @@ #include #include #include +#include +#include +#include #include "utils/utils.h" #include "utils/base_ref.h" #include "session/anf_runtime_algorithm.h" @@ -35,6 +38,56 @@ std::vector Convert2Int(const std::vector &v) { return result; } +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + std::vector node_list = TopoSort(graph->get_return()); + std::map> control_depend_map; + for (auto &nd : node_list) { + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + auto control_depend = nd->cast(); + auto prior_node = control_depend->input(kControlDependPriorIndex); + auto behind_node = control_depend->input(kControlDependBehindIndex); + auto it = control_depend_map.find(behind_node); + if (it == control_depend_map.end()) { + control_depend_map[behind_node] = std::set{prior_node}; + } else { + it->second.insert(prior_node); + } + } + } + + FuncGraphManagerPtr manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + std::unordered_set seen_node; + std::deque todo{node1}; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + + if (node == node2) { + return true; + } + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + auto it = control_depend_map.find(node); + if (it != control_depend_map.end()) { + (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); + } + } + return false; +} + bool UnVisited(const BaseRef &n) { if (utils::isa(n)) { AnfNodePtr in = utils::cast(n); diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 1e27db132e..f244baa4a1 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -111,6 +111,9 @@ enum ConvBn1Output { std::vector Convert2Int(const std::vector &v); +// check whether node1 depends on node2 or not +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); + bool UnVisited(const BaseRef &n); bool Visited(const BaseRef &n);