| @@ -51,6 +51,7 @@ | |||||
| #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" | #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" | ||||
| #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" | #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" | ||||
| #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | ||||
| #include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | #include "pre_activate/ascend/format_type/insert_trans_op.h" | ||||
| #include "pre_activate/pass/getitem_tuple.h" | #include "pre_activate/pass/getitem_tuple.h" | ||||
| #include "pre_activate/pass/optimize_dependence.h" | #include "pre_activate/pass/optimize_dependence.h" | ||||
| @@ -100,6 +101,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ||||
| @@ -73,13 +73,16 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An | |||||
| return mul0; | return mul0; | ||||
| } | } | ||||
| bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, | |||||
| const AnfNodePtr &reduce_sum) { | |||||
| MS_EXCEPTION_IF_NULL(mul0_anf); | MS_EXCEPTION_IF_NULL(mul0_anf); | ||||
| MS_EXCEPTION_IF_NULL(mul1_anf); | |||||
| MS_EXCEPTION_IF_NULL(reduce_sum); | MS_EXCEPTION_IF_NULL(reduce_sum); | ||||
| if (!mul0_anf->isa<CNode>()) { | |||||
| if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| auto mul1 = mul1_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(mul1); | |||||
| auto mul0 = mul0_anf->cast<CNodePtr>(); | auto mul0 = mul0_anf->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(mul0); | MS_EXCEPTION_IF_NULL(mul0); | ||||
| @@ -88,20 +91,14 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf | |||||
| return true; | return true; | ||||
| } | } | ||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| if (manager->node_users().find(reduce_sum) == manager->node_users().end()) { | |||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| if (IsDepend(graph, mul0->input(1), reduce_sum)) { | |||||
| MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; | |||||
| return true; | |||||
| } | } | ||||
| const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum]; | |||||
| auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair<AnfNodePtr, int> &node_index) { | |||||
| return node_index.first == mul0->input(1) || node_index.first == mul0; | |||||
| }); | |||||
| if (it != outputs_set.end()) { | |||||
| MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"; | |||||
| if (IsDepend(graph, mul1->input(1), mul0)) { | |||||
| MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -131,7 +128,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons | |||||
| MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; | MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (QuitFusion(graph, mul0, node)) { | |||||
| if (QuitFusion(graph, mul0, mul1, node)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -18,6 +18,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <map> | |||||
| #include <set> | |||||
| #include <deque> | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/base_ref.h" | #include "utils/base_ref.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -35,6 +38,56 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) { | |||||
| return result; | return result; | ||||
| } | } | ||||
| bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node1); | |||||
| MS_EXCEPTION_IF_NULL(node2); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | |||||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; | |||||
| for (auto &nd : node_list) { | |||||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { | |||||
| auto control_depend = nd->cast<CNodePtr>(); | |||||
| auto prior_node = control_depend->input(kControlDependPriorIndex); | |||||
| auto behind_node = control_depend->input(kControlDependBehindIndex); | |||||
| auto it = control_depend_map.find(behind_node); | |||||
| if (it == control_depend_map.end()) { | |||||
| control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node}; | |||||
| } else { | |||||
| it->second.insert(prior_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| FuncGraphManagerPtr manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> seen_node; | |||||
| std::deque<AnfNodePtr> todo{node1}; | |||||
| while (!todo.empty()) { | |||||
| AnfNodePtr node = todo.front(); | |||||
| todo.pop_front(); | |||||
| if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { | |||||
| continue; | |||||
| } | |||||
| (void)seen_node.insert(node); | |||||
| if (node == node2) { | |||||
| return true; | |||||
| } | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto inputs = cnode->inputs(); | |||||
| (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); | |||||
| } | |||||
| auto it = control_depend_map.find(node); | |||||
| if (it != control_depend_map.end()) { | |||||
| (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool UnVisited(const BaseRef &n) { | bool UnVisited(const BaseRef &n) { | ||||
| if (utils::isa<AnfNodePtr>(n)) { | if (utils::isa<AnfNodePtr>(n)) { | ||||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | AnfNodePtr in = utils::cast<AnfNodePtr>(n); | ||||
| @@ -111,6 +111,9 @@ enum ConvBn1Output { | |||||
| std::vector<int> Convert2Int(const std::vector<size_t> &v); | std::vector<int> Convert2Int(const std::vector<size_t> &v); | ||||
| // check whether node1 depends on node2 or not | |||||
| bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); | |||||
| bool UnVisited(const BaseRef &n); | bool UnVisited(const BaseRef &n); | ||||
| bool Visited(const BaseRef &n); | bool Visited(const BaseRef &n); | ||||