| @@ -24,6 +24,7 @@ | |||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "backend/optimizer/common/fusion_id_allocator.h" | #include "backend/optimizer/common/fusion_id_allocator.h" | ||||
| #include "backend/optimizer/common/helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -59,6 +60,10 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod | |||||
| auto bnupdate = getitem->input(1); | auto bnupdate = getitem->input(1); | ||||
| MS_EXCEPTION_IF_NULL(bnupdate); | MS_EXCEPTION_IF_NULL(bnupdate); | ||||
| if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | ||||
| if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE && | |||||
| IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) { | |||||
| return; | |||||
| } | |||||
| std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | ||||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | for (auto out_getitem : manager->node_users()[bnupdate]) { | ||||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | MS_EXCEPTION_IF_NULL(out_getitem.first); | ||||
| @@ -97,11 +97,11 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf | |||||
| auto mul0 = mul0_anf->cast<CNodePtr>(); | auto mul0 = mul0_anf->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(mul0); | MS_EXCEPTION_IF_NULL(mul0); | ||||
| if (IsDepend(graph, mul0->input(1), reduce_sum)) { | |||||
| if (IsDepend(*graph, mul0->input(1), {reduce_sum})) { | |||||
| MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; | MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; | ||||
| return true; | return true; | ||||
| } | } | ||||
| if (IsDepend(graph, mul1->input(1), mul0)) { | |||||
| if (IsDepend(*graph, mul1->input(1), {mul0})) { | |||||
| MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; | MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -39,11 +39,9 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) { | |||||
| return result; | return result; | ||||
| } | } | ||||
| bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node1); | |||||
| MS_EXCEPTION_IF_NULL(node2); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | |||||
| bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return()); | |||||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; | std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; | ||||
| for (auto &nd : node_list) { | for (auto &nd : node_list) { | ||||
| MS_EXCEPTION_IF_NULL(nd); | MS_EXCEPTION_IF_NULL(nd); | ||||
| @@ -60,29 +58,29 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphManagerPtr manager = graph->manager(); | |||||
| FuncGraphManagerPtr manager = graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> seen_node; | std::unordered_set<AnfNodePtr> seen_node; | ||||
| std::deque<AnfNodePtr> todo{node1}; | |||||
| std::deque<AnfNodePtr> todo{node}; | |||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| AnfNodePtr node = todo.front(); | |||||
| AnfNodePtr nd = todo.front(); | |||||
| todo.pop_front(); | todo.pop_front(); | ||||
| if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { | |||||
| if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| (void)seen_node.insert(node); | |||||
| (void)seen_node.insert(nd); | |||||
| if (node == node2) { | |||||
| if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (nd->isa<CNode>()) { | |||||
| auto cnode = nd->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); | (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); | ||||
| } | } | ||||
| auto it = control_depend_map.find(node); | |||||
| auto it = control_depend_map.find(nd); | |||||
| if (it != control_depend_map.end()) { | if (it != control_depend_map.end()) { | ||||
| (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); | (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); | ||||
| } | } | ||||
| @@ -119,8 +119,8 @@ enum ConvBn1Output { | |||||
| std::vector<int> Convert2Int(const std::vector<size_t> &v); | std::vector<int> Convert2Int(const std::vector<size_t> &v); | ||||
| // check whether node1 depends on node2 or not | |||||
| bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); | |||||
| // check whether node depends on either of nodes or not | |||||
| bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes); | |||||
| bool UnVisited(const BaseRef &n); | bool UnVisited(const BaseRef &n); | ||||