Merge pull request !2047 from huanghui/code-reviewtags/v0.5.0-beta
| @@ -346,7 +346,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| save_graphs_path = "."; | save_graphs_path = "."; | ||||
| } | } | ||||
| if (save_graphs) { | if (save_graphs) { | ||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_before.ir"; | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | DumpIR(file_path, kernel_graph); | ||||
| } | } | ||||
| auto fusion_id_allocator = std::make_shared<FusionIdAllocator>(); | auto fusion_id_allocator = std::make_shared<FusionIdAllocator>(); | ||||
| @@ -372,7 +373,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| if (save_graphs) { | if (save_graphs) { | ||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_after.ir"; | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | DumpIR(file_path, kernel_graph); | ||||
| } | } | ||||
| } | } | ||||
| @@ -34,16 +34,22 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | auto manager = kernel_graph.manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(relu_input); | |||||
| auto add = relu_input->cast<CNodePtr>(); | auto add = relu_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(add); | MS_EXCEPTION_IF_NULL(add); | ||||
| auto tuple_getitem = add->input(1); | auto tuple_getitem = add->input(1); | ||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||||
| if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { | if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { | ||||
| auto getitem = tuple_getitem->cast<CNodePtr>(); | auto getitem = tuple_getitem->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(getitem); | |||||
| auto bnupdate = getitem->input(1); | auto bnupdate = getitem->input(1); | ||||
| MS_EXCEPTION_IF_NULL(bnupdate); | |||||
| if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | ||||
| std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | ||||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | for (auto out_getitem : manager->node_users()[bnupdate]) { | ||||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | |||||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | |||||
| auto input2 = out_getitem_ptr->input(2); | auto input2 = out_getitem_ptr->input(2); | ||||
| auto output_idx = GetValue<int>(GetValueNode(input2)); | auto output_idx = GetValue<int>(GetValueNode(input2)); | ||||
| output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); | output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); | ||||
| @@ -34,12 +34,17 @@ void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const A | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | auto manager = kernel_graph.manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(relu_input); | |||||
| auto getitem = relu_input->cast<CNodePtr>(); | auto getitem = relu_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(getitem); | |||||
| auto bnupdate = getitem->input(1); | auto bnupdate = getitem->input(1); | ||||
| MS_EXCEPTION_IF_NULL(bnupdate); | |||||
| if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { | ||||
| std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | ||||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | for (auto out_getitem : manager->node_users()[bnupdate]) { | ||||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | |||||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | |||||
| auto input2 = out_getitem_ptr->input(2); | auto input2 = out_getitem_ptr->input(2); | ||||
| auto output_idx = GetValue<int>(GetValueNode(input2)); | auto output_idx = GetValue<int>(GetValueNode(input2)); | ||||
| output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); | output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); | ||||
| @@ -35,6 +35,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | ||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| } else { | } else { | ||||
| @@ -43,6 +44,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| auto double_in_eltwise_input = input_cnode->input(1); | auto double_in_eltwise_input = input_cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(double_in_eltwise_input); | |||||
| if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -36,6 +36,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -35,6 +35,7 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess | |||||
| auto manager = kernel_graph.manager(); | auto manager = kernel_graph.manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto conv = cnode->input(1); | auto conv = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(conv); | |||||
| if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { | if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { | ||||
| std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())}; | std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())}; | ||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); | AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); | ||||
| @@ -35,6 +35,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | ||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| } else { | } else { | ||||
| @@ -43,6 +44,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con | |||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| auto double_in_eltwise_input = input_cnode->input(1); | auto double_in_eltwise_input = input_cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(double_in_eltwise_input); | |||||
| if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -44,6 +44,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -35,6 +35,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | while (CheckEltWiseNode(manager.get(), eltwise_input)) { | ||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| if (record.size() == MAX_ELTWISE_SIZE) { | if (record.size() == MAX_ELTWISE_SIZE) { | ||||
| @@ -57,6 +58,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | ||||
| std::reverse(node_list.begin(), node_list.end()); | std::reverse(node_list.begin(), node_list.end()); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | ||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | ||||
| continue; | continue; | ||||
| @@ -25,6 +25,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -38,6 +39,7 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt | |||||
| bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -51,6 +53,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const A | |||||
| bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -55,6 +55,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap | |||||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | ||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { | AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { | if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { | ||||
| MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); | MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); | ||||
| } | } | ||||
| @@ -35,6 +35,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { | if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { | ||||
| std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; | std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; | ||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); | AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); | ||||
| @@ -45,6 +45,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -44,6 +44,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | ||||
| return; | return; | ||||
| @@ -45,6 +45,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con | |||||
| write_input = input_cnode->input(1); | write_input = input_cnode->input(1); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(write_input); | |||||
| if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) || | if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(write_input)) { | fusion_id_allocator->HasFusionIdAttr(write_input)) { | ||||
| return; | return; | ||||
| @@ -57,6 +58,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con | |||||
| conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { | conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { | ||||
| (void)record.insert(write_input); | (void)record.insert(write_input); | ||||
| auto conv_input = conv_cnode->input(1); | auto conv_input = conv_cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(conv_input); | |||||
| if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) || | if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) || | ||||
| fusion_id_allocator->HasFusionIdAttr(conv_input)) { | fusion_id_allocator->HasFusionIdAttr(conv_input)) { | ||||
| return; | return; | ||||
| @@ -206,6 +206,7 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi | |||||
| void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, | void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, | ||||
| std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { | std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { | ||||
| MS_EXCEPTION_IF_NULL(buffer_fusion_infos); | MS_EXCEPTION_IF_NULL(buffer_fusion_infos); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto nodes = TopoSort(kernel_graph->get_return()); | auto nodes = TopoSort(kernel_graph->get_return()); | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -231,6 +232,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, | |||||
| auto fusion_info = buffer_fusion_info.second; | auto fusion_info = buffer_fusion_info.second; | ||||
| for (const auto &node : fusion_info.anf_nodes) { | for (const auto &node : fusion_info.anf_nodes) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { | for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { | ||||
| auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); | auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); | ||||
| if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == | if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == | ||||
| @@ -253,6 +255,14 @@ bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||||
| auto getitem2 = node2->cast<CNodePtr>(); | auto getitem2 = node2->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(getitem1); | MS_EXCEPTION_IF_NULL(getitem1); | ||||
| MS_EXCEPTION_IF_NULL(getitem2); | MS_EXCEPTION_IF_NULL(getitem2); | ||||
| if (getitem1->size() < kTupleGetItemInputSize) { | |||||
| MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" | |||||
| << getitem1->DebugString() << "]"; | |||||
| } | |||||
| if (getitem2->size() < kTupleGetItemInputSize) { | |||||
| MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" | |||||
| << getitem2->DebugString() << "]"; | |||||
| } | |||||
| auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); | auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); | ||||
| auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); | auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); | ||||
| return output_idx1 < output_idx2; | return output_idx1 < output_idx2; | ||||
| @@ -285,6 +295,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, | |||||
| [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); | [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); | ||||
| std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); | std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); | ||||
| for (auto getitem : tuple_getitem_nodes) { | for (auto getitem : tuple_getitem_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(getitem); | |||||
| auto getitem_ptr = getitem->cast<CNodePtr>(); | auto getitem_ptr = getitem->cast<CNodePtr>(); | ||||
| auto input2 = getitem_ptr->input(2); | auto input2 = getitem_ptr->input(2); | ||||
| auto output_idx = GetValue<int>(GetValueNode(input2)); | auto output_idx = GetValue<int>(GetValueNode(input2)); | ||||
| @@ -313,6 +324,7 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| for (size_t idx = 0; idx < outputs_list.size(); ++idx) { | for (size_t idx = 0; idx < outputs_list.size(); ++idx) { | ||||
| auto output = outputs_list[idx]; | auto output = outputs_list[idx]; | ||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { | if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { | ||||
| auto real_output = AnfAlgo::VisitKernel(output, 0); | auto real_output = AnfAlgo::VisitKernel(output, 0); | ||||
| auto output_cnode = output->cast<CNodePtr>(); | auto output_cnode = output->cast<CNodePtr>(); | ||||
| @@ -393,6 +405,7 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph | |||||
| bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, | bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, | ||||
| int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, | int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, | ||||
| session::KernelGraph *kernel_graph) const { | session::KernelGraph *kernel_graph) const { | ||||
| MS_EXCEPTION_IF_NULL(buffer_fusion_infos); | |||||
| auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; | auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; | ||||
| auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, | auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, | ||||
| buffer_fusion_info.anf_nodes, kernel_graph); | buffer_fusion_info.anf_nodes, kernel_graph); | ||||
| @@ -1,157 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "device/kernel_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| constexpr size_t kBn2AddReluOutputNum = 4; | |||||
| enum Bn2AddReluOutput { | |||||
| kBn2AddReluOutput = 0, | |||||
| kBn2AddReluRunningMean, | |||||
| kBn2AddReluRunningVariance, | |||||
| kBn2AddReluSaveInvVariance, | |||||
| }; | |||||
| std::tuple<CNodePtr, CNodePtr, CNodePtr, CNodePtr> GetUsedCNode(const AnfNodePtr &node) { | |||||
| auto relu_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kReluInputNum); | |||||
| MS_EXCEPTION_IF_NULL(relu_cnode); | |||||
| auto add_cnode = CheckAnfNodeIfCNodeAndInputSize(relu_cnode->input(1), kAddInputNum); | |||||
| MS_EXCEPTION_IF_NULL(add_cnode); | |||||
| auto add_input1_cnode = CheckAnfNodeIfCNodeAndInputSize(add_cnode->input(1), kTupleGetitemInputNum); | |||||
| MS_EXCEPTION_IF_NULL(add_input1_cnode); | |||||
| auto bn_cnode = CheckAnfNodeIfCNodeAndInputSize(add_input1_cnode->input(1), kBnInputNum); | |||||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||||
| auto conv_cnode = CheckAnfNodeIfCNodeAndInputSize(bn_cnode->input(kX), kConvInputNum); | |||||
| return std::make_tuple(conv_cnode, bn_cnode, add_cnode, relu_cnode); | |||||
| } | |||||
| void CreateOutputsOfBn2AddRelu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs, | |||||
| const CNodePtr &bn_node, const CNodePtr &add_node, const CNodePtr &relu_node, | |||||
| std::vector<AnfNodePtr> *bn2_add_relu_outputs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(add_node); | |||||
| MS_EXCEPTION_IF_NULL(relu_node); | |||||
| MS_EXCEPTION_IF_NULL(bn_node); | |||||
| auto prim = std::make_shared<Primitive>(kBN2AddReluOpName); | |||||
| std::vector<AnfNodePtr> bn2_add_relu_inputs = {NewValueNode(prim)}; | |||||
| // The inputs of bn2_add_relu are from the outputs of conv_bn1, the 2nd input of add, and the 2nd to 5th inputs of bn | |||||
| (void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_add_relu_inputs)); | |||||
| bn2_add_relu_inputs.push_back(add_node->input(2)); | |||||
| for (size_t i = kX + 1; i <= kVariance; i++) { | |||||
| bn2_add_relu_inputs.push_back(bn_node->input(i)); | |||||
| } | |||||
| auto bn2_add_relu_cnode = func_graph->NewCNode(bn2_add_relu_inputs); | |||||
| MS_EXCEPTION_IF_NULL(bn2_add_relu_cnode); | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| bn2_add_relu_cnode->set_kernel_info(kernel_info); | |||||
| // Set attr for bn2_add_relu | |||||
| AnfAlgo::CopyNodeAttrs(bn_node, bn2_add_relu_cnode); | |||||
| AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_add_relu_cnode); | |||||
| // Set abstract of bn2_add_relu | |||||
| auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_node->abstract()); | |||||
| MS_EXCEPTION_IF_NULL(bn_abstract_tuple); | |||||
| if (bn_abstract_tuple->elements().size() != kBnOutputNum) { | |||||
| MS_LOG(EXCEPTION) << "Abstract tuple size of FusedBatchNorm must be " << kBnOutputNum << ", but it is " | |||||
| << bn_abstract_tuple->elements().size(); | |||||
| } | |||||
| auto relu_abstract = relu_node->abstract(); | |||||
| MS_EXCEPTION_IF_NULL(relu_abstract); | |||||
| // The abstracts of node bn2_add_relu are from the some abstracts of bn and relu nodes. | |||||
| AbstractBasePtrList bn2_add_relu_abstract_list{relu_abstract, bn_abstract_tuple->elements()[kRunningMean], | |||||
| bn_abstract_tuple->elements()[kRunningVariance], | |||||
| bn_abstract_tuple->elements()[kSaveInvVariance]}; | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(bn2_add_relu_abstract_list); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||||
| bn2_add_relu_cnode->set_abstract(abstract_tuple); | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, bn2_add_relu_cnode, kBn2AddReluOutputNum, bn2_add_relu_outputs); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef ConvBnAddReluFusion::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(X); | |||||
| VarPtr W = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(W); | |||||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||||
| MS_EXCEPTION_IF_NULL(Ys); | |||||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||||
| MS_EXCEPTION_IF_NULL(Zs); | |||||
| return VectorRef( | |||||
| {prim::kPrimRelu, | |||||
| PatternListType( | |||||
| {prim::kPrimTensorAdd, | |||||
| PatternListType({prim::kPrimTupleGetItem, | |||||
| PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Ys}), Zs}), | |||||
| W}), | |||||
| X})}); | |||||
| } | |||||
| const AnfNodePtr ConvBnAddReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| CNodePtr conv_cnode = nullptr; | |||||
| CNodePtr bn_cnode = nullptr; | |||||
| CNodePtr add_cnode = nullptr; | |||||
| CNodePtr relu_cnode = nullptr; | |||||
| std::tie(conv_cnode, bn_cnode, add_cnode, relu_cnode) = GetUsedCNode(node); | |||||
| // Create conv_bn1 node and get outputs of conv_bn1 | |||||
| std::vector<AnfNodePtr> conv_bn1_outputs; | |||||
| CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs); | |||||
| if (conv_bn1_outputs.size() != kConvBn1OutputNum) { | |||||
| MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is " | |||||
| << conv_bn1_outputs.size(); | |||||
| } | |||||
| // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others | |||||
| (void)manager->Replace(conv_cnode, conv_bn1_outputs[kData]); | |||||
| // Create bn2_add_relu node and get outputs of bn2_add_relu | |||||
| std::vector<AnfNodePtr> bn2_add_relu_outputs; | |||||
| CreateOutputsOfBn2AddRelu(func_graph, conv_bn1_outputs, bn_cnode, add_cnode, relu_cnode, &bn2_add_relu_outputs); | |||||
| if (bn2_add_relu_outputs.size() != kBn2AddReluOutputNum) { | |||||
| MS_LOG(EXCEPTION) << "The output size of node bn2_add_relu must be " << kBn2AddReluOutputNum << ", but it is " | |||||
| << bn2_add_relu_outputs.size(); | |||||
| } | |||||
| // Create a make_tuple to replace the bn node here, the outputs are from node bn2_add_relu and conv_bn1. | |||||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), | |||||
| bn2_add_relu_outputs[kBn2AddReluOutput], | |||||
| bn2_add_relu_outputs[kBn2AddReluRunningMean], | |||||
| bn2_add_relu_outputs[kBn2AddReluRunningVariance], | |||||
| conv_bn1_outputs[kMean], | |||||
| bn2_add_relu_outputs[kBn2AddReluSaveInvVariance]}; | |||||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| (void)manager->Replace(bn_cnode, make_tuple); | |||||
| return bn2_add_relu_outputs[kBn2AddReluOutput]; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_ | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ConvBnAddReluFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit ConvBnAddReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_add_relu_fusion", multigraph) {} | |||||
| ~ConvBnAddReluFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_ | |||||
| @@ -1,93 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "device/kernel_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef ConvBnFusion::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| MS_EXCEPTION_IF_NULL(Xs); | |||||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||||
| MS_EXCEPTION_IF_NULL(Ys); | |||||
| return VectorRef({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys}); | |||||
| } | |||||
| const AnfNodePtr ConvBnFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "The bn node is expected to be a cnode"; | |||||
| } | |||||
| auto bn_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||||
| if (bn_cnode->inputs().size() < kVariance + 1) { | |||||
| auto op_name = AnfAlgo::GetCNodeName(bn_cnode); | |||||
| MS_LOG(EXCEPTION) << "op[" << op_name << "] has less than " << kVariance + 1 << " inputs."; | |||||
| } | |||||
| AnfNodePtr conv_node = bn_cnode->input(kX); | |||||
| MS_EXCEPTION_IF_NULL(conv_node); | |||||
| if (!conv_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "The conv node is expected to be a cnode"; | |||||
| } | |||||
| auto conv_cnode = conv_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_cnode); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| // Create conv_bn1 node and get outputs of conv_bn1 | |||||
| std::vector<AnfNodePtr> conv_bn1_outputs; | |||||
| CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs); | |||||
| if (conv_bn1_outputs.size() != kConvBn1OutputNum) { | |||||
| MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is " | |||||
| << conv_bn1_outputs.size(); | |||||
| } | |||||
| // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by other | |||||
| (void)manager->Replace(conv_node, conv_bn1_outputs[kData]); | |||||
| // Create bn2 node and get outputs of bn2 | |||||
| std::vector<AnfNodePtr> bn2_outputs; | |||||
| std::vector<AnfNodePtr> bn1_outputs = {conv_bn1_outputs[2], conv_bn1_outputs[1]}; | |||||
| CreateOutputsOfFusedBn2(func_graph, bn1_outputs, bn_cnode, &bn2_outputs); | |||||
| if (bn2_outputs.size() != kBN2OutputNum) { | |||||
| MS_LOG(EXCEPTION) << "The output size of node fusedbn2 must be " << kBN2OutputNum << ", but it is " | |||||
| << bn2_outputs.size(); | |||||
| } | |||||
| // Create bn3 node and get outputs of bn3 | |||||
| std::vector<AnfNodePtr> bn3_outputs; | |||||
| CreateOutputsOfFusedBn3(func_graph, conv_bn1_outputs[0], bn1_outputs, bn2_outputs, bn_cnode, &bn3_outputs); | |||||
| if (bn3_outputs.size() != kBN3OutputNum) { | |||||
| MS_LOG(EXCEPTION) << "The output size of node fusedbn3 must be " << kBN3OutputNum << ", but it is " | |||||
| << bn3_outputs.size(); | |||||
| } | |||||
| // Return a make_tuple to replace the bn node here, the outputs are from node bn2 and conv_bn1. | |||||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), | |||||
| bn3_outputs[0], | |||||
| bn2_outputs[1], | |||||
| bn2_outputs[2], | |||||
| conv_bn1_outputs[2], | |||||
| bn2_outputs[0]}; | |||||
| return func_graph->NewCNode(make_tuple_inputs); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_ | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ConvBnFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit ConvBnFusion(bool multigraph = true) : PatternProcessPass("conv_bn_fusion", multigraph) {} | |||||
| ~ConvBnFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_ | |||||
| @@ -1,140 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <tuple> | |||||
| #include "utils/utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "common/utils.h" | |||||
| #include "device/kernel_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| std::tuple<CNodePtr, CNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto relu_node = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(relu_node); | |||||
| if (relu_node->inputs().size() < kReluInputNum) { | |||||
| MS_LOG(EXCEPTION) << "relu has wrong input size"; | |||||
| } | |||||
| auto tuple_getitem_anf = relu_node->input(1); | |||||
| MS_EXCEPTION_IF_NULL(tuple_getitem_anf); | |||||
| auto tuple_getitem = tuple_getitem_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||||
| if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) { | |||||
| MS_LOG(EXCEPTION) << "tuple getitem has wrong input size"; | |||||
| } | |||||
| auto bn_node_anf = tuple_getitem->input(1); | |||||
| MS_EXCEPTION_IF_NULL(bn_node_anf); | |||||
| auto bn_node = bn_node_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(bn_node); | |||||
| if (bn_node->inputs().size() < kBnInputNum) { | |||||
| MS_LOG(EXCEPTION) << "bn_node has wrong input size"; | |||||
| } | |||||
| auto conv_node_anf = bn_node->input(1); | |||||
| MS_EXCEPTION_IF_NULL(conv_node_anf); | |||||
| CNodePtr conv_node = conv_node_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_node); | |||||
| return std::make_tuple(bn_node, bn_node, conv_node); | |||||
| } | |||||
| void CreateOutputsOfBn2Relu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs, | |||||
| const CNodePtr &bn_node, const CNodePtr &relu_node, | |||||
| std::vector<AnfNodePtr> *bn2_relu_outputs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(bn_node); | |||||
| MS_EXCEPTION_IF_NULL(relu_node); | |||||
| // The inputs of bn2_relu are from the outputs of conv_bn1 and the 2nd to 5th inputs of bn | |||||
| std::vector<AnfNodePtr> bn2_relu_inputs = {NewValueNode(std::make_shared<Primitive>(kBN2ReLUOpName))}; | |||||
| (void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_relu_inputs)); | |||||
| for (size_t i = 2; i <= 5; i++) { | |||||
| bn2_relu_inputs.push_back(bn_node->input(i)); | |||||
| } | |||||
| auto bn2_relu = func_graph->NewCNode(bn2_relu_inputs); | |||||
| MS_EXCEPTION_IF_NULL(bn2_relu); | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| bn2_relu->set_kernel_info(kernel_info); | |||||
| auto types = {AnfAlgo::GetOutputInferDataType(relu_node, 0), AnfAlgo::GetOutputInferDataType(bn_node, 1), | |||||
| AnfAlgo::GetOutputInferDataType(bn_node, 2), AnfAlgo::GetOutputInferDataType(bn_node, 4)}; | |||||
| auto shapes = {AnfAlgo::GetOutputInferShape(relu_node, 0), AnfAlgo::GetOutputInferShape(bn_node, 1), | |||||
| AnfAlgo::GetOutputInferShape(bn_node, 2), AnfAlgo::GetOutputInferShape(bn_node, 4)}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn2_relu.get()); | |||||
| // Set attr for bn2_add_relu | |||||
| AnfAlgo::CopyNodeAttrs(bn_node, bn2_relu); | |||||
| AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_relu); | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, bn2_relu, kBn2ReluOutputNum, bn2_relu_outputs); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef ConvBnReluFusion::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||||
| VarPtr Z = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(Xs); | |||||
| MS_EXCEPTION_IF_NULL(Ys); | |||||
| MS_EXCEPTION_IF_NULL(Z); | |||||
| return VectorRef( | |||||
| {prim::kPrimRelu, | |||||
| PatternListType({prim::kPrimTupleGetItem, | |||||
| PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys}), Z})}); | |||||
| } | |||||
| const AnfNodePtr ConvBnReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| CNodePtr relu_node = nullptr; | |||||
| CNodePtr bn_node = nullptr; | |||||
| CNodePtr conv_node = nullptr; | |||||
| std::tie(relu_node, bn_node, conv_node) = GetPrevNodes(node); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<AnfNodePtr> conv_bn1_outputs; | |||||
| CreateOutputsOfConvBn1(func_graph, conv_node, bn_node, &conv_bn1_outputs); | |||||
| if (conv_bn1_outputs.size() != kConvBn1OutputNum) { | |||||
| MS_LOG(EXCEPTION) << "conv_bn1 outputs has wrong size: " << conv_bn1_outputs.size(); | |||||
| } | |||||
| (void)manager->Replace(conv_node, conv_bn1_outputs[0]); | |||||
| std::vector<AnfNodePtr> bn2_relu_outputs; | |||||
| CreateOutputsOfBn2Relu(func_graph, conv_bn1_outputs, bn_node, relu_node, &bn2_relu_outputs); | |||||
| if (bn2_relu_outputs.size() != kBn2ReluOutputNum) { | |||||
| MS_LOG(EXCEPTION) << "bn2_relu outputs has wrong size: " << bn2_relu_outputs.size(); | |||||
| } | |||||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), | |||||
| bn2_relu_outputs[0], | |||||
| bn2_relu_outputs[1], | |||||
| bn2_relu_outputs[2], | |||||
| conv_bn1_outputs[2], | |||||
| bn2_relu_outputs[3]}; | |||||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||||
| (void)manager->Replace(bn_node, make_tuple); | |||||
| return bn2_relu_outputs[0]; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,33 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_ | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ConvBnReluFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit ConvBnReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_relu_fusion", multigraph) {} | |||||
| ~ConvBnReluFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_ | |||||
| @@ -28,6 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); | MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -37,7 +38,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| save_graphs_path = "."; | save_graphs_path = "."; | ||||
| } | } | ||||
| if (save_graphs) { | if (save_graphs) { | ||||
| std::string file_path = save_graphs_path + "/" + "hwopt_common_before.ir"; | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | DumpIR(file_path, kernel_graph); | ||||
| } | } | ||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | auto optimizer = std::make_shared<GraphOptimizer>(); | ||||
| @@ -51,7 +53,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| if (save_graphs) { | if (save_graphs) { | ||||
| std::string file_path = save_graphs_path + "/" + "hwopt_common_after.ir"; | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | DumpIR(file_path, kernel_graph); | ||||
| } | } | ||||
| } | } | ||||
| @@ -45,6 +45,7 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | ||||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; | std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; | ||||
| for (auto &nd : node_list) { | for (auto &nd : node_list) { | ||||
| MS_EXCEPTION_IF_NULL(nd); | |||||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { | if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { | ||||
| auto control_depend = nd->cast<CNodePtr>(); | auto control_depend = nd->cast<CNodePtr>(); | ||||
| auto prior_node = control_depend->input(kControlDependPriorIndex); | auto prior_node = control_depend->input(kControlDependPriorIndex); | ||||
| @@ -157,6 +158,7 @@ const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const An | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); | auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); | ||||
| MS_EXCEPTION_IF_NULL(transop_cnode); | |||||
| auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); | auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); | ||||
| auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); | auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); | ||||
| MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); | MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); | ||||
| @@ -545,14 +547,22 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | ||||
| auto a_node = utils::cast<AnfNodePtr>(a); | auto a_node = utils::cast<AnfNodePtr>(a); | ||||
| auto b_node = utils::cast<AnfNodePtr>(b); | auto b_node = utils::cast<AnfNodePtr>(b); | ||||
| MS_EXCEPTION_IF_NULL(a_node); | |||||
| MS_EXCEPTION_IF_NULL(b_node); | |||||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | ||||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | auto a_value_node = a_node->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(a_value_node); | |||||
| auto a_value = a_value_node->value(); | auto a_value = a_value_node->value(); | ||||
| MS_EXCEPTION_IF_NULL(a_value); | |||||
| auto a_prim = a_value->cast<PrimitivePtr>(); | auto a_prim = a_value->cast<PrimitivePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(a_prim); | |||||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | auto b_value_node = b_node->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(b_value_node); | |||||
| auto b_value = b_value_node->value(); | auto b_value = b_value_node->value(); | ||||
| MS_EXCEPTION_IF_NULL(b_value); | |||||
| auto b_prim = b_value->cast<PrimitivePtr>(); | auto b_prim = b_value->cast<PrimitivePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(b_prim); | |||||
| return a_prim->name() == b_prim->name(); | return a_prim->name() == b_prim->name(); | ||||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | ||||
| @@ -1,77 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "operator/ops.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/common/pass_manager.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "device/kernel_info.h" | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||||
| class TestHWConvBnFusion : public BackendCommon { | |||||
| public: | |||||
| TestHWConvBnFusion() : getPyFun_("gtest_input.pre_activate.ir_fusion_test", true) {} | |||||
| ~TestHWConvBnFusion() override = default; | |||||
| UT::PyFuncGraphFetcher getPyFun_; | |||||
| }; | |||||
| TEST_F(TestHWConvBnFusion, test_conv_bn_fusion) { | |||||
| /* | |||||
| * def before(x, y): | |||||
| * conv_output = conv(x, y) | |||||
| * bn_output = bn(conv_output) | |||||
| * item0 = tuple_getitem(bn_output, 0) | |||||
| * item1 = tuple_getitem(bn_output, 3) | |||||
| * item2 = tuple_getitem(bn_output, 4) | |||||
| * res = make_tuple(item0, item1, item2) | |||||
| * return res | |||||
| */ | |||||
| getPyFun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "before"); | |||||
| std::vector<int> shp_x{32, 3, 224, 224}; | |||||
| std::vector<int> shp_w{64, 3, 7, 7}; | |||||
| std::vector<int> shp_b{64}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w); | |||||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto graph_optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pass_manager = std::make_shared<opt::PassManager>(); | |||||
| auto conv_bn_fusion_pass = std::make_shared<opt::ConvBnFusion>(); | |||||
| pass_manager->AddPass(conv_bn_fusion_pass); | |||||
| graph_optimizer->AddPassManager(pass_manager); | |||||
| auto new_g = graph_optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_g)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,62 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TestHWConvBnReluFusion : public BackendCommon { | |||||
| public: | |||||
| TestHWConvBnReluFusion() : get_py_fun_("gtest_input.pre_activate.conv_bn_relu_fusion", true) {} | |||||
| ~TestHWConvBnReluFusion() override = default; | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | |||||
| }; | |||||
| TEST_F(TestHWConvBnReluFusion, test_conv_bn_relu_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_conv_bn_relu_fusion", "before"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{32, 3, 224, 224}; | |||||
| std::vector<int> shp_w{64, 3, 7, 7}; | |||||
| std::vector<int> shp_b{64}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w); | |||||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; | |||||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||||
| // do bn_grad_split_pass | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::ConvBnReluFusion>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kernel_graph); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_conv_bn_relu_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,71 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.ops import operations as P | |||||
| make_tuple = Primitive('make_tuple') | |||||
| tuple_getitem = Primitive('tuple_getitem') | |||||
| conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1) | |||||
| bn = P.FusedBatchNorm() | |||||
| relu = P.ReLU() | |||||
| conv_bn1 = Primitive('ConvBN1') | |||||
| bn2_relu = Primitive('BN2Relu') | |||||
| class FnDict: | |||||
| def __init__(self): | |||||
| self.fnDict = {} | |||||
| def __call__(self, fn): | |||||
| self.fnDict[fn.__name__] = fn | |||||
| def __getitem__(self, name): | |||||
| return self.fnDict[name] | |||||
| def test_conv_bn_relu_fusion(tag): | |||||
| """ test_conv_bn_relu_fusion """ | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(x, w, scale, b, mean, variance): | |||||
| conv_output = conv(x, w) | |||||
| bn_output = bn(conv_output, scale, b, mean, variance) | |||||
| item0 = tuple_getitem(bn_output, 0) | |||||
| item1 = tuple_getitem(bn_output, 3) | |||||
| item2 = tuple_getitem(bn_output, 4) | |||||
| output = make_tuple(relu(item0), item1, item2) | |||||
| res = tuple_getitem(output, 0) | |||||
| return res | |||||
| @fns | |||||
| def after(x, w, scale, b, mean, variance): | |||||
| conv_bn1_output = conv_bn1(x, w) | |||||
| conv_item0 = tuple_getitem(conv_bn1_output, 0) | |||||
| conv_item1 = tuple_getitem(conv_bn1_output, 1) | |||||
| conv_item2 = tuple_getitem(conv_bn1_output, 2) | |||||
| bn2_relu_output = bn2_relu(conv_item0, conv_item1, conv_item2, scale, b, mean, variance) | |||||
| bn2_relu_item0 = tuple_getitem(bn2_relu_output, 0) | |||||
| bn2_relu_item1 = tuple_getitem(bn2_relu_output, 1) | |||||
| bn2_relu_item2 = tuple_getitem(bn2_relu_output, 2) | |||||
| bn2_relu_item3 = tuple_getitem(bn2_relu_output, 3) | |||||
| new_make_tuple = make_tuple(bn2_relu_item0, bn2_relu_item1, bn2_relu_item2, conv_item2, bn2_relu_item3) | |||||
| item1 = tuple_getitem(new_make_tuple, 3) | |||||
| item2 = tuple_getitem(new_make_tuple, 4) | |||||
| output = make_tuple(bn2_relu_item0, item1, item2) | |||||
| return make_tuple(tuple_getitem(output, 0)) | |||||
| return fns[tag] | |||||
| @@ -49,104 +49,6 @@ class FnDict: | |||||
| return self.fnDict[name] | return self.fnDict[name] | ||||
| def test_conv_bn_fusion(tag): | |||||
| """ test_conv_bn_fusion """ | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(x, w, scale, b, mean, variance): | |||||
| conv_output = conv(x, w) | |||||
| bn_output = bn(conv_output, scale, b, mean, variance) | |||||
| item0 = tuple_getitem(bn_output, 0) | |||||
| item1 = tuple_getitem(bn_output, 3) | |||||
| item2 = tuple_getitem(bn_output, 4) | |||||
| output = make_tuple(item0, item1, item2) | |||||
| res = tuple_getitem(output, 0) | |||||
| return res | |||||
| @fns | |||||
| def after(x, w, scale, b, mean, variance): | |||||
| conv_bn1_output = conv_bn1(x, w) | |||||
| conv_item0 = tuple_getitem(conv_bn1_output, 0) | |||||
| conv_item1 = tuple_getitem(conv_bn1_output, 1) | |||||
| conv_item2 = tuple_getitem(conv_bn1_output, 2) | |||||
| bn2_output = fused_bn2(conv_item2, conv_item1, mean, variance) | |||||
| bn2_item0 = tuple_getitem(bn2_output, 0) | |||||
| bn2_item1 = tuple_getitem(bn2_output, 1) | |||||
| bn2_item2 = tuple_getitem(bn2_output, 2) | |||||
| bn3_output = fused_bn3(conv_item0, conv_item2, bn2_item0, scale, b) | |||||
| output = make_tuple(bn3_output, bn2_item1, bn2_item2, conv_item2, bn2_item0) | |||||
| item0 = tuple_getitem(output, 0) | |||||
| item1 = tuple_getitem(output, 3) | |||||
| item2 = tuple_getitem(output, 4) | |||||
| new_output = make_tuple(item0, item1, item2) | |||||
| return make_tuple(tuple_getitem(new_output, 0)) | |||||
| return fns[tag] | |||||
| def test_conv_bn_add_relu_fusion(tag): | |||||
| """ test_conv_bn_add_relu_fusion """ | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(x, w, scale, b, mean, variance, y): | |||||
| conv_output = conv(x, w) | |||||
| bn_output = bn(conv_output, scale, b, mean, variance) | |||||
| item0 = tuple_getitem(bn_output, 0) | |||||
| s = add(item0, y) | |||||
| res = relu(s) | |||||
| return res | |||||
| @fns | |||||
| def after(x, w, scale, b, mean, variance, y): | |||||
| conv_bn1_output = conv_bn1(x, w) | |||||
| conv_item0 = tuple_getitem(conv_bn1_output, 0) | |||||
| conv_item1 = tuple_getitem(conv_bn1_output, 1) | |||||
| conv_item2 = tuple_getitem(conv_bn1_output, 2) | |||||
| bn2_add_relu_output = bn2_add_relu(conv_item0, conv_item1, conv_item2, y, scale, b, mean, variance) | |||||
| bn2_add_relu_item0 = tuple_getitem(bn2_add_relu_output, 0) | |||||
| res = make_tuple(bn2_add_relu_item0) | |||||
| return res | |||||
| return fns[tag] | |||||
| def test_conv_bn_relu_fusion(tag): | |||||
| """ test_conv_bn_relu_fusion """ | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(x, w, scale, b, mean, variance): | |||||
| conv_output = conv(x, w) | |||||
| bn_output = bn(conv_output, scale, b, mean, variance) | |||||
| item0 = tuple_getitem(bn_output, 0) | |||||
| item1 = tuple_getitem(bn_output, 3) | |||||
| item2 = tuple_getitem(bn_output, 4) | |||||
| output = make_tuple(relu(item0), item1, item2) | |||||
| res = tuple_getitem(output, 0) | |||||
| return res | |||||
| @fns | |||||
| def after(x, w, scale, b, mean, variance): | |||||
| conv_bn1_output = conv_bn1(x, w) | |||||
| conv_item0 = tuple_getitem(conv_bn1_output, 0) | |||||
| conv_item1 = tuple_getitem(conv_bn1_output, 1) | |||||
| conv_item2 = tuple_getitem(conv_bn1_output, 2) | |||||
| bn2_relu_output = bn2_relu(conv_item0, conv_item1, conv_item2, scale, b, mean, variance) | |||||
| bn2_relu_item0 = tuple_getitem(bn2_relu_output, 0) | |||||
| bn2_relu_item1 = tuple_getitem(bn2_relu_output, 1) | |||||
| bn2_relu_item2 = tuple_getitem(bn2_relu_output, 2) | |||||
| bn2_relu_item3 = tuple_getitem(bn2_relu_output, 3) | |||||
| new_make_tuple = make_tuple(bn2_relu_item0, bn2_relu_item1, bn2_relu_item2, conv_item2, bn2_relu_item3) | |||||
| item1 = tuple_getitem(new_make_tuple, 3) | |||||
| item2 = tuple_getitem(new_make_tuple, 4) | |||||
| output = make_tuple(bn2_relu_item0, item1, item2) | |||||
| return make_tuple(tuple_getitem(output, 0)) | |||||
| return fns[tag] | |||||
| def test_bn_split(tag): | def test_bn_split(tag): | ||||
| """ test_split_bn_fusion """ | """ test_split_bn_fusion """ | ||||
| fns = FnDict() | fns = FnDict() | ||||