Browse Source

!2047 [Code review] check code files in pre_activate/ascend/buffer_fusion

Merge pull request !2047 from huanghui/code-review
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
928c25ebea
28 changed files with 61 additions and 803 deletions
  1. +4
    -2
      mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  2. +6
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc
  3. +5
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
  4. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc
  5. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc
  6. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc
  7. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc
  8. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc
  9. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc
  10. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc
  11. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc
  12. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc
  13. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc
  14. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc
  15. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
  16. +13
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc
  17. +0
    -157
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.cc
  18. +0
    -34
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h
  19. +0
    -93
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_fusion.cc
  20. +0
    -34
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_fusion.h
  21. +0
    -140
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.cc
  22. +0
    -33
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h
  23. +5
    -2
      mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
  24. +10
    -0
      mindspore/ccsrc/pre_activate/common/helper.cc
  25. +0
    -77
      tests/ut/cpp/pre_activate/ascend/ir_fusion/conv_bn_fusion_test.cc
  26. +0
    -62
      tests/ut/cpp/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion_test.cc
  27. +0
    -71
      tests/ut/cpp/python_input/gtest_input/pre_activate/conv_bn_relu_fusion.py
  28. +0
    -98
      tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py

+ 4
- 2
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc View File

@@ -346,7 +346,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
save_graphs_path = "."; save_graphs_path = ".";
} }
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_before.ir";
std::string file_path =
save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph); DumpIR(file_path, kernel_graph);
} }
auto fusion_id_allocator = std::make_shared<FusionIdAllocator>(); auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
@@ -372,7 +373,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_after.ir";
std::string file_path =
save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph); DumpIR(file_path, kernel_graph);
} }
} }


+ 6
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc View File

@@ -34,16 +34,22 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(candidate_fusion); MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager(); auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(relu_input);
auto add = relu_input->cast<CNodePtr>(); auto add = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add); MS_EXCEPTION_IF_NULL(add);
auto tuple_getitem = add->input(1); auto tuple_getitem = add->input(1);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
auto getitem = tuple_getitem->cast<CNodePtr>(); auto getitem = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1); auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) { for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2); auto input2 = out_getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2)); auto output_idx = GetValue<int>(GetValueNode(input2));
output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());


+ 5
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc View File

@@ -34,12 +34,17 @@ void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const A
MS_EXCEPTION_IF_NULL(candidate_fusion); MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager(); auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(relu_input);
auto getitem = relu_input->cast<CNodePtr>(); auto getitem = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1); auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) { for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2); auto input2 = out_getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2)); auto output_idx = GetValue<int>(GetValueNode(input2));
output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc View File

@@ -35,6 +35,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input); (void)record.insert(eltwise_input);
} else { } else {
@@ -43,6 +44,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
auto input_cnode = eltwise_input->cast<CNodePtr>(); auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode); MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(1); auto double_in_eltwise_input = input_cnode->input(1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
return; return;


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc View File

@@ -36,6 +36,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return; return;


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc View File

@@ -35,6 +35,7 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
auto manager = kernel_graph.manager(); auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto conv = cnode->input(1); auto conv = cnode->input(1);
MS_EXCEPTION_IF_NULL(conv);
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())}; std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc View File

@@ -35,6 +35,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input); (void)record.insert(eltwise_input);
} else { } else {
@@ -43,6 +44,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
auto input_cnode = eltwise_input->cast<CNodePtr>(); auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode); MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(1); auto double_in_eltwise_input = input_cnode->input(1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
return; return;


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc View File

@@ -44,6 +44,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
break; break;
} }
} }
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return; return;


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc View File

@@ -35,6 +35,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
while (CheckEltWiseNode(manager.get(), eltwise_input)) { while (CheckEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input); (void)record.insert(eltwise_input);
if (record.size() == MAX_ELTWISE_SIZE) { if (record.size() == MAX_ELTWISE_SIZE) {
@@ -57,6 +58,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end()); std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) { for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue; continue;


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc View File

@@ -25,6 +25,7 @@ namespace mindspore {
namespace opt { namespace opt {
bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false; return false;
} }
@@ -38,6 +39,7 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt


bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false; return false;
} }
@@ -51,6 +53,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const A


bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false; return false;
} }


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc View File

@@ -55,6 +55,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
} }


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc View File

@@ -35,6 +35,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode}; std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1); auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc View File

@@ -45,6 +45,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
break; break;
} }
} }
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return; return;


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc View File

@@ -44,6 +44,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
break; break;
} }
} }
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return; return;


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc View File

@@ -45,6 +45,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
write_input = input_cnode->input(1); write_input = input_cnode->input(1);
} }


MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) || if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
fusion_id_allocator->HasFusionIdAttr(write_input)) { fusion_id_allocator->HasFusionIdAttr(write_input)) {
return; return;
@@ -57,6 +58,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
(void)record.insert(write_input); (void)record.insert(write_input);
auto conv_input = conv_cnode->input(1); auto conv_input = conv_cnode->input(1);
MS_EXCEPTION_IF_NULL(conv_input);
if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) || if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) ||
fusion_id_allocator->HasFusionIdAttr(conv_input)) { fusion_id_allocator->HasFusionIdAttr(conv_input)) {
return; return;


+ 13
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc View File

@@ -206,6 +206,7 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto nodes = TopoSort(kernel_graph->get_return()); auto nodes = TopoSort(kernel_graph->get_return());
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@@ -231,6 +232,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
auto fusion_info = buffer_fusion_info.second; auto fusion_info = buffer_fusion_info.second;
for (const auto &node : fusion_info.anf_nodes) { for (const auto &node : fusion_info.anf_nodes) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
@@ -253,6 +255,14 @@ bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
auto getitem2 = node2->cast<CNodePtr>(); auto getitem2 = node2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1); MS_EXCEPTION_IF_NULL(getitem1);
MS_EXCEPTION_IF_NULL(getitem2); MS_EXCEPTION_IF_NULL(getitem2);
if (getitem1->size() < kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
<< getitem1->DebugString() << "]";
}
if (getitem2->size() < kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem2["
<< getitem2->DebugString() << "]";
}
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
return output_idx1 < output_idx2; return output_idx1 < output_idx2;
@@ -285,6 +295,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
for (auto getitem : tuple_getitem_nodes) { for (auto getitem : tuple_getitem_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
auto getitem_ptr = getitem->cast<CNodePtr>(); auto getitem_ptr = getitem->cast<CNodePtr>();
auto input2 = getitem_ptr->input(2); auto input2 = getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2)); auto output_idx = GetValue<int>(GetValueNode(input2));
@@ -313,6 +324,7 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
for (size_t idx = 0; idx < outputs_list.size(); ++idx) { for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
auto output = outputs_list[idx]; auto output = outputs_list[idx];
MS_EXCEPTION_IF_NULL(output);
if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
auto real_output = AnfAlgo::VisitKernel(output, 0); auto real_output = AnfAlgo::VisitKernel(output, 0);
auto output_cnode = output->cast<CNodePtr>(); auto output_cnode = output->cast<CNodePtr>();
@@ -393,6 +405,7 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
session::KernelGraph *kernel_graph) const { session::KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
buffer_fusion_info.anf_nodes, kernel_graph); buffer_fusion_info.anf_nodes, kernel_graph);


+ 0
- 157
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.cc View File

@@ -1,157 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h"
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include <tuple>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBn2AddReluOutputNum = 4;
enum Bn2AddReluOutput {
kBn2AddReluOutput = 0,
kBn2AddReluRunningMean,
kBn2AddReluRunningVariance,
kBn2AddReluSaveInvVariance,
};

std::tuple<CNodePtr, CNodePtr, CNodePtr, CNodePtr> GetUsedCNode(const AnfNodePtr &node) {
auto relu_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kReluInputNum);
MS_EXCEPTION_IF_NULL(relu_cnode);
auto add_cnode = CheckAnfNodeIfCNodeAndInputSize(relu_cnode->input(1), kAddInputNum);
MS_EXCEPTION_IF_NULL(add_cnode);
auto add_input1_cnode = CheckAnfNodeIfCNodeAndInputSize(add_cnode->input(1), kTupleGetitemInputNum);
MS_EXCEPTION_IF_NULL(add_input1_cnode);
auto bn_cnode = CheckAnfNodeIfCNodeAndInputSize(add_input1_cnode->input(1), kBnInputNum);
MS_EXCEPTION_IF_NULL(bn_cnode);
auto conv_cnode = CheckAnfNodeIfCNodeAndInputSize(bn_cnode->input(kX), kConvInputNum);

return std::make_tuple(conv_cnode, bn_cnode, add_cnode, relu_cnode);
}

// Builds a single fused BN2AddRelu CNode that replaces the bn/add/relu trio,
// wires its inputs from the conv_bn1 outputs plus the original bn/add inputs,
// and returns its unpacked outputs through *bn2_add_relu_outputs
// (kBn2AddReluOutputNum tuple-getitem nodes).
void CreateOutputsOfBn2AddRelu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
const CNodePtr &bn_node, const CNodePtr &add_node, const CNodePtr &relu_node,
std::vector<AnfNodePtr> *bn2_add_relu_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(add_node);
MS_EXCEPTION_IF_NULL(relu_node);
MS_EXCEPTION_IF_NULL(bn_node);
auto prim = std::make_shared<Primitive>(kBN2AddReluOpName);
std::vector<AnfNodePtr> bn2_add_relu_inputs = {NewValueNode(prim)};
// The inputs of bn2_add_relu are from the outputs of conv_bn1, the 2nd input of add, and the 2nd to 5th inputs of bn
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_add_relu_inputs));
bn2_add_relu_inputs.push_back(add_node->input(2));
// kX+1..kVariance covers scale/offset/mean/variance inputs of the bn node.
for (size_t i = kX + 1; i <= kVariance; i++) {
bn2_add_relu_inputs.push_back(bn_node->input(i));
}
auto bn2_add_relu_cnode = func_graph->NewCNode(bn2_add_relu_inputs);
MS_EXCEPTION_IF_NULL(bn2_add_relu_cnode);
// A fresh KernelInfo is required before kernel selection can run on the node.
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
bn2_add_relu_cnode->set_kernel_info(kernel_info);

// Set attr for bn2_add_relu: inherit everything from bn, renaming "epsilon" to "eps".
AnfAlgo::CopyNodeAttrs(bn_node, bn2_add_relu_cnode);
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_add_relu_cnode);

// Set abstract of bn2_add_relu: bn must expose exactly kBnOutputNum outputs.
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_node->abstract());
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
MS_LOG(EXCEPTION) << "Abstract tuple size of FusedBatchNorm must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size();
}
auto relu_abstract = relu_node->abstract();
MS_EXCEPTION_IF_NULL(relu_abstract);
// The abstracts of node bn2_add_relu combine the relu output abstract with the
// running-mean/running-variance/save-inv-variance abstracts taken from bn.
AbstractBasePtrList bn2_add_relu_abstract_list{relu_abstract, bn_abstract_tuple->elements()[kRunningMean],
bn_abstract_tuple->elements()[kRunningVariance],
bn_abstract_tuple->elements()[kSaveInvVariance]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(bn2_add_relu_abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
bn2_add_relu_cnode->set_abstract(abstract_tuple);

// Unpack the fused node's tuple output into individual getitem nodes.
CreateMultipleOutputsOfAnfNode(func_graph, bn2_add_relu_cnode, kBn2AddReluOutputNum, bn2_add_relu_outputs);
}
} // namespace

// Pattern anchor for the pass:
//   Relu(TensorAdd(TupleGetItem(FusedBatchNorm(Conv2D(Ys...), Zs...), W), X))
// X = second addend of TensorAdd, W = tuple index, Ys/Zs = remaining
// Conv2D / FusedBatchNorm inputs (sequence variables).
const BaseRef ConvBnAddReluFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
VarPtr W = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(W);
VarPtr Ys = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Ys);
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Zs);

return VectorRef(
{prim::kPrimRelu,
PatternListType(
{prim::kPrimTensorAdd,
PatternListType({prim::kPrimTupleGetItem,
PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Ys}), Zs}),
W}),
X})});
}

// Rewrites the matched Conv2D+FusedBatchNorm+TensorAdd+Relu subgraph into the
// fused pair conv_bn1 / bn2_add_relu. `node` is the Relu anchor matched by
// DefinePattern. Returns the main output of the fused bn2_add_relu node;
// throws (MS_LOG(EXCEPTION)) if either fused node produces an unexpected
// number of outputs.
const AnfNodePtr ConvBnAddReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
CNodePtr conv_cnode = nullptr;
CNodePtr bn_cnode = nullptr;
CNodePtr add_cnode = nullptr;
CNodePtr relu_cnode = nullptr;
std::tie(conv_cnode, bn_cnode, add_cnode, relu_cnode) = GetUsedCNode(node);
// Create conv_bn1 node and get outputs of conv_bn1
std::vector<AnfNodePtr> conv_bn1_outputs;
CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
<< conv_bn1_outputs.size();
}
// Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others
(void)manager->Replace(conv_cnode, conv_bn1_outputs[kData]);

// Create bn2_add_relu node and get outputs of bn2_add_relu
std::vector<AnfNodePtr> bn2_add_relu_outputs;
CreateOutputsOfBn2AddRelu(func_graph, conv_bn1_outputs, bn_cnode, add_cnode, relu_cnode, &bn2_add_relu_outputs);
if (bn2_add_relu_outputs.size() != kBn2AddReluOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn2_add_relu must be " << kBn2AddReluOutputNum << ", but it is "
<< bn2_add_relu_outputs.size();
}

// Create a make_tuple to replace the bn node here, the outputs are from node bn2_add_relu and conv_bn1.
// Tuple order mirrors the original bn's output order so downstream
// TupleGetItem users keep their indices valid after the Replace below.
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
bn2_add_relu_outputs[kBn2AddReluOutput],
bn2_add_relu_outputs[kBn2AddReluRunningMean],
bn2_add_relu_outputs[kBn2AddReluRunningVariance],
conv_bn1_outputs[kMean],
bn2_add_relu_outputs[kBn2AddReluSaveInvVariance]};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
(void)manager->Replace(bn_cnode, make_tuple);
// The returned node substitutes the matched Relu anchor in the graph.
return bn2_add_relu_outputs[kBn2AddReluOutput];
}
} // namespace opt
} // namespace mindspore

+ 0
- 34
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h View File

@@ -1,34 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
// Pattern pass that fuses a Conv2D + FusedBatchNorm + TensorAdd + Relu chain into
// fused ConvBN1 / BN2AddRelu kernels (implementation in the matching .cc file).
class ConvBnAddReluFusion : public PatternProcessPass {
 public:
  explicit ConvBnAddReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_add_relu_fusion", multigraph) {}
  ~ConvBnAddReluFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_

+ 0
- 93
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_fusion.cc View File

@@ -1,93 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#include <memory>
#include <vector>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
const BaseRef ConvBnFusion::DefinePattern() const {
  // Pattern: FusedBatchNorm(Conv2D(conv_inputs...), bn_rest_inputs...)
  VarPtr conv_inputs = std::make_shared<SeqVar>();
  VarPtr bn_rest_inputs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(conv_inputs);
  MS_EXCEPTION_IF_NULL(bn_rest_inputs);
  return VectorRef({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, conv_inputs}), bn_rest_inputs});
}

const AnfNodePtr ConvBnFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  // Splits the matched Conv2D + FusedBatchNorm pair into three fused kernels
  // (ConvBN1, FusedBN2, FusedBN3) and returns a make_tuple that mimics the original
  // FusedBatchNorm tuple output, so downstream tuple_getitem consumers still work.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The bn node is expected to be a cnode";
  }
  auto bn_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(bn_cnode);
  // The bn cnode must carry at least inputs up to the variance slot.
  if (bn_cnode->inputs().size() < kVariance + 1) {
    auto op_name = AnfAlgo::GetCNodeName(bn_cnode);
    MS_LOG(EXCEPTION) << "op[" << op_name << "] has less than " << kVariance + 1 << " inputs.";
  }
  AnfNodePtr conv_node = bn_cnode->input(kX);
  MS_EXCEPTION_IF_NULL(conv_node);
  if (!conv_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The conv node is expected to be a cnode";
  }
  auto conv_cnode = conv_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(conv_cnode);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Create conv_bn1 node and get outputs of conv_bn1
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
                      << conv_bn1_outputs.size();
  }
  // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by other
  (void)manager->Replace(conv_node, conv_bn1_outputs[kData]);

  // Create bn2 node and get outputs of bn2
  std::vector<AnfNodePtr> bn2_outputs;
  // NOTE(review): bn1 outputs are deliberately passed as {conv_bn1[2], conv_bn1[1]} —
  // order looks intentional but the index semantics are defined elsewhere; confirm.
  std::vector<AnfNodePtr> bn1_outputs = {conv_bn1_outputs[2], conv_bn1_outputs[1]};
  CreateOutputsOfFusedBn2(func_graph, bn1_outputs, bn_cnode, &bn2_outputs);
  if (bn2_outputs.size() != kBN2OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn2 must be " << kBN2OutputNum << ", but it is "
                      << bn2_outputs.size();
  }

  // Create bn3 node and get outputs of bn3
  std::vector<AnfNodePtr> bn3_outputs;
  CreateOutputsOfFusedBn3(func_graph, conv_bn1_outputs[0], bn1_outputs, bn2_outputs, bn_cnode, &bn3_outputs);

  if (bn3_outputs.size() != kBN3OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn3 must be " << kBN3OutputNum << ", but it is "
                      << bn3_outputs.size();
  }

  // Return a make_tuple to replace the bn node here, the outputs are from node bn2 and conv_bn1.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn3_outputs[0],
                                            bn2_outputs[1],
                                            bn2_outputs[2],
                                            conv_bn1_outputs[2],
                                            bn2_outputs[0]};

  return func_graph->NewCNode(make_tuple_inputs);
}
} // namespace opt
} // namespace mindspore

+ 0
- 34
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_fusion.h View File

@@ -1,34 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
// Pattern pass that splits Conv2D + FusedBatchNorm into fused ConvBN1 / FusedBN2 /
// FusedBN3 kernels (implementation in the matching .cc file).
class ConvBnFusion : public PatternProcessPass {
 public:
  explicit ConvBnFusion(bool multigraph = true) : PatternProcessPass("conv_bn_fusion", multigraph) {}
  ~ConvBnFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_

+ 0
- 140
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.cc View File

@@ -1,140 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <tuple>

#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
namespace {
std::tuple<CNodePtr, CNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
  // Walks up from the matched Relu node through the
  // Relu -> TupleGetItem -> FusedBatchNorm -> Conv2D chain and returns the
  // (relu, bn, conv) cnodes, validating each node's input arity along the way.
  MS_EXCEPTION_IF_NULL(node);
  auto relu_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(relu_node);
  if (relu_node->inputs().size() < kReluInputNum) {
    MS_LOG(EXCEPTION) << "relu has wrong input size";
  }
  auto tuple_getitem_anf = relu_node->input(1);
  MS_EXCEPTION_IF_NULL(tuple_getitem_anf);
  auto tuple_getitem = tuple_getitem_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) {
    MS_LOG(EXCEPTION) << "tuple getitem has wrong input size";
  }
  auto bn_node_anf = tuple_getitem->input(1);
  MS_EXCEPTION_IF_NULL(bn_node_anf);
  auto bn_node = bn_node_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(bn_node);
  if (bn_node->inputs().size() < kBnInputNum) {
    MS_LOG(EXCEPTION) << "bn_node has wrong input size";
  }
  auto conv_node_anf = bn_node->input(1);
  MS_EXCEPTION_IF_NULL(conv_node_anf);
  CNodePtr conv_node = conv_node_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(conv_node);
  // Bug fix: the original code returned bn_node twice, so the caller's std::tie
  // received the bn node in the relu slot and read the relu's inferred
  // type/shape from the wrong node. The first element must be relu_node.
  return std::make_tuple(relu_node, bn_node, conv_node);
}

void CreateOutputsOfBn2Relu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
const CNodePtr &bn_node, const CNodePtr &relu_node,
std::vector<AnfNodePtr> *bn2_relu_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_node);
MS_EXCEPTION_IF_NULL(relu_node);
// The inputs of bn2_relu are from the outputs of conv_bn1 and the 2nd to 5th inputs of bn
std::vector<AnfNodePtr> bn2_relu_inputs = {NewValueNode(std::make_shared<Primitive>(kBN2ReLUOpName))};
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_relu_inputs));
for (size_t i = 2; i <= 5; i++) {
bn2_relu_inputs.push_back(bn_node->input(i));
}
auto bn2_relu = func_graph->NewCNode(bn2_relu_inputs);
MS_EXCEPTION_IF_NULL(bn2_relu);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
bn2_relu->set_kernel_info(kernel_info);
auto types = {AnfAlgo::GetOutputInferDataType(relu_node, 0), AnfAlgo::GetOutputInferDataType(bn_node, 1),
AnfAlgo::GetOutputInferDataType(bn_node, 2), AnfAlgo::GetOutputInferDataType(bn_node, 4)};
auto shapes = {AnfAlgo::GetOutputInferShape(relu_node, 0), AnfAlgo::GetOutputInferShape(bn_node, 1),
AnfAlgo::GetOutputInferShape(bn_node, 2), AnfAlgo::GetOutputInferShape(bn_node, 4)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn2_relu.get());
// Set attr for bn2_add_relu
AnfAlgo::CopyNodeAttrs(bn_node, bn2_relu);
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_relu);

CreateMultipleOutputsOfAnfNode(func_graph, bn2_relu, kBn2ReluOutputNum, bn2_relu_outputs);
}
} // namespace

const BaseRef ConvBnReluFusion::DefinePattern() const {
  // Pattern: Relu(TupleGetItem(FusedBatchNorm(Conv2D(conv_inputs...), bn_rest_inputs...), getitem_index))
  VarPtr conv_inputs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(conv_inputs);
  VarPtr bn_rest_inputs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(bn_rest_inputs);
  VarPtr getitem_index = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(getitem_index);
  return VectorRef(
    {prim::kPrimRelu,
     PatternListType({prim::kPrimTupleGetItem,
                      PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, conv_inputs}),
                                       bn_rest_inputs}),
                      getitem_index})});
}

const AnfNodePtr ConvBnReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
  // Rewrites the matched Conv2D + FusedBatchNorm + Relu chain into fused ConvBN1 and
  // BN2ReLU nodes, replaces the bn node with an equivalent make_tuple, and returns
  // the fused relu output.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  CNodePtr relu_node = nullptr;
  CNodePtr bn_node = nullptr;
  CNodePtr conv_node = nullptr;
  std::tie(relu_node, bn_node, conv_node) = GetPrevNodes(node);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  // Create conv_bn1 and validate its output count.
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_node, bn_node, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "conv_bn1 outputs has wrong size: " << conv_bn1_outputs.size();
  }
  // Replace the conv node with conv_bn1 output 0 because other nodes may consume it.
  (void)manager->Replace(conv_node, conv_bn1_outputs[0]);

  // Create the fused bn2_relu node and validate its output count.
  std::vector<AnfNodePtr> bn2_relu_outputs;
  CreateOutputsOfBn2Relu(func_graph, conv_bn1_outputs, bn_node, relu_node, &bn2_relu_outputs);
  if (bn2_relu_outputs.size() != kBn2ReluOutputNum) {
    MS_LOG(EXCEPTION) << "bn2_relu outputs has wrong size: " << bn2_relu_outputs.size();
  }
  // NOTE(review): slot 3 of the tuple comes from conv_bn1 output 2 while the rest come
  // from bn2_relu — presumably matching the original bn output order; confirm indices.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn2_relu_outputs[0],
                                            bn2_relu_outputs[1],
                                            bn2_relu_outputs[2],
                                            conv_bn1_outputs[2],
                                            bn2_relu_outputs[3]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  (void)manager->Replace(bn_node, make_tuple);
  return bn2_relu_outputs[0];
}
} // namespace opt
} // namespace mindspore

+ 0
- 33
mindspore/ccsrc/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h View File

@@ -1,33 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
// Pattern pass that fuses a Conv2D + FusedBatchNorm + Relu chain into fused
// ConvBN1 / BN2ReLU kernels (implementation in the matching .cc file).
class ConvBnReluFusion : public PatternProcessPass {
 public:
  explicit ConvBnReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_relu_fusion", multigraph) {}
  ~ConvBnReluFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_

+ 5
- 2
mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc View File

@@ -28,6 +28,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
@@ -37,7 +38,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path = "."; save_graphs_path = ".";
} }
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_common_before.ir";
std::string file_path =
save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph); DumpIR(file_path, kernel_graph);
} }
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
@@ -51,7 +53,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_common_after.ir";
std::string file_path =
save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph); DumpIR(file_path, kernel_graph);
} }
} }


+ 10
- 0
mindspore/ccsrc/pre_activate/common/helper.cc View File

@@ -45,6 +45,7 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
for (auto &nd : node_list) { for (auto &nd : node_list) {
MS_EXCEPTION_IF_NULL(nd);
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
auto control_depend = nd->cast<CNodePtr>(); auto control_depend = nd->cast<CNodePtr>();
auto prior_node = control_depend->input(kControlDependPriorIndex); auto prior_node = control_depend->input(kControlDependPriorIndex);
@@ -157,6 +158,7 @@ const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);


auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
MS_EXCEPTION_IF_NULL(transop_cnode);
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
@@ -545,14 +547,22 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a); auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b); auto b_node = utils::cast<AnfNodePtr>(b);
MS_EXCEPTION_IF_NULL(a_node);
MS_EXCEPTION_IF_NULL(b_node);
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
auto a_value_node = a_node->cast<ValueNodePtr>(); auto a_value_node = a_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(a_value_node);
auto a_value = a_value_node->value(); auto a_value = a_value_node->value();
MS_EXCEPTION_IF_NULL(a_value);
auto a_prim = a_value->cast<PrimitivePtr>(); auto a_prim = a_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(a_prim);


auto b_value_node = b_node->cast<ValueNodePtr>(); auto b_value_node = b_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(b_value_node);
auto b_value = b_value_node->value(); auto b_value = b_value_node->value();
MS_EXCEPTION_IF_NULL(b_value);
auto b_prim = b_value->cast<PrimitivePtr>(); auto b_prim = b_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(b_prim);


return a_prim->name() == b_prim->name(); return a_prim->name() == b_prim->name();
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {


+ 0
- 77
tests/ut/cpp/pre_activate/ascend/ir_fusion/conv_bn_fusion_test.cc View File

@@ -1,77 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#undef private
#undef protected

namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

// Test fixture for the ConvBnFusion pass; fetches "before"/"after" graphs from the
// Python fixture module gtest_input.pre_activate.ir_fusion_test.
class TestHWConvBnFusion : public BackendCommon {
 public:
  TestHWConvBnFusion() : getPyFun_("gtest_input.pre_activate.ir_fusion_test", true) {}
  ~TestHWConvBnFusion() override = default;

  // Fetches Python-defined test graphs by (module, tag).
  UT::PyFuncGraphFetcher getPyFun_;
};

TEST_F(TestHWConvBnFusion, test_conv_bn_fusion) {
  /*
   * def before(x, y):
   *   conv_output = conv(x, y)
   *   bn_output = bn(conv_output)
   *   item0 = tuple_getitem(bn_output, 0)
   *   item1 = tuple_getitem(bn_output, 3)
   *   item2 = tuple_getitem(bn_output, 4)
   *   res = make_tuple(item0, item1, item2)
   *   return res
   */
  getPyFun_.SetDoResolve(true);
  // Build the "before" graph with concrete float32 tensor shapes so abstract inference can run.
  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "before");
  std::vector<int> shp_x{32, 3, 224, 224};
  std::vector<int> shp_w{64, 3, 7, 7};
  std::vector<int> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto fg = GetKernelGraph(g, args_spec_list);

  // Run only the ConvBnFusion pass over the kernel graph.
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto conv_bn_fusion_pass = std::make_shared<opt::ConvBnFusion>();
  pass_manager->AddPass(conv_bn_fusion_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(fg);

  // The optimized graph must structurally equal the hand-written "after" graph.
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
}

} // namespace opt
} // namespace mindspore

+ 0
- 62
tests/ut/cpp/pre_activate/ascend/ir_fusion/conv_bn_relu_fusion_test.cc View File

@@ -1,62 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
#include "kernel/kernel_build_info.h"

#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"
#undef private
#undef protected

namespace mindspore {
namespace opt {
// Test fixture for the ConvBnReluFusion pass; fetches "before"/"after" graphs from
// the Python fixture module gtest_input.pre_activate.conv_bn_relu_fusion.
class TestHWConvBnReluFusion : public BackendCommon {
 public:
  TestHWConvBnReluFusion() : get_py_fun_("gtest_input.pre_activate.conv_bn_relu_fusion", true) {}
  ~TestHWConvBnReluFusion() override = default;

  // Fetches Python-defined test graphs by (module, tag).
  UT::PyFuncGraphFetcher get_py_fun_;
};

TEST_F(TestHWConvBnReluFusion, test_conv_bn_relu_fusion) {
  // Build the "before" graph with concrete float32 tensor shapes.
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_conv_bn_relu_fusion", "before");
  ASSERT_TRUE(g != nullptr);
  std::vector<int> shp_x{32, 3, 224, 224};
  std::vector<int> shp_w{64, 3, 7, 7};
  std::vector<int> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(g, args_spec_list);

  // Run only the ConvBnReluFusion pass (the original comment said "bn_grad_split_pass",
  // which did not match the pass actually exercised here).
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::ConvBnReluFusion>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);

  // The optimized graph must structurally equal the hand-written "after" graph.
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_conv_bn_relu_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

+ 0
- 71
tests/ut/cpp/python_input/gtest_input/pre_activate/conv_bn_relu_fusion.py View File

@@ -1,71 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
bn = P.FusedBatchNorm()
relu = P.ReLU()
conv_bn1 = Primitive('ConvBN1')
bn2_relu = Primitive('BN2Relu')


class FnDict:
    """Tiny registry mapping a function's ``__name__`` to the function.

    Used as a decorator in graph-fixture modules: ``@fns`` registers a graph
    builder (the decorated name is rebound to ``None``, since ``__call__``
    returns nothing), and ``fns[tag]`` retrieves it by name.
    """

    def __init__(self):
        # name -> function registry
        self.fnDict = {}

    def __call__(self, fn):
        # Register under the function's own name; deliberately returns None.
        key = fn.__name__
        self.fnDict[key] = fn

    def __getitem__(self, name):
        # Look up a previously registered function by name.
        return self.fnDict[name]


def test_conv_bn_relu_fusion(tag):
    """ test_conv_bn_relu_fusion """
    # Returns the graph builder registered under `tag` ("before" or "after").
    fns = FnDict()

    @fns
    def before(x, w, scale, b, mean, variance):
        # Conv2D -> FusedBatchNorm; relu is applied to bn output 0,
        # and bn outputs 3 and 4 are also consumed.
        conv_output = conv(x, w)
        bn_output = bn(conv_output, scale, b, mean, variance)
        item0 = tuple_getitem(bn_output, 0)
        item1 = tuple_getitem(bn_output, 3)
        item2 = tuple_getitem(bn_output, 4)
        output = make_tuple(relu(item0), item1, item2)
        res = tuple_getitem(output, 0)
        return res

    @fns
    def after(x, w, scale, b, mean, variance):
        # Expected graph after the fusion pass: ConvBN1 followed by BN2Relu;
        # the make_tuple mixes bn2_relu outputs with conv_bn1 output 2.
        conv_bn1_output = conv_bn1(x, w)
        conv_item0 = tuple_getitem(conv_bn1_output, 0)
        conv_item1 = tuple_getitem(conv_bn1_output, 1)
        conv_item2 = tuple_getitem(conv_bn1_output, 2)
        bn2_relu_output = bn2_relu(conv_item0, conv_item1, conv_item2, scale, b, mean, variance)
        bn2_relu_item0 = tuple_getitem(bn2_relu_output, 0)
        bn2_relu_item1 = tuple_getitem(bn2_relu_output, 1)
        bn2_relu_item2 = tuple_getitem(bn2_relu_output, 2)
        bn2_relu_item3 = tuple_getitem(bn2_relu_output, 3)
        new_make_tuple = make_tuple(bn2_relu_item0, bn2_relu_item1, bn2_relu_item2, conv_item2, bn2_relu_item3)
        item1 = tuple_getitem(new_make_tuple, 3)
        item2 = tuple_getitem(new_make_tuple, 4)
        output = make_tuple(bn2_relu_item0, item1, item2)
        return make_tuple(tuple_getitem(output, 0))

    return fns[tag]

+ 0
- 98
tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py View File

@@ -49,104 +49,6 @@ class FnDict:
return self.fnDict[name] return self.fnDict[name]




def test_conv_bn_fusion(tag):
    """ test_conv_bn_fusion """
    # Returns the graph builder registered under `tag` ("before" or "after").
    fns = FnDict()

    @fns
    def before(x, w, scale, b, mean, variance):
        # Conv2D feeding FusedBatchNorm; only bn outputs 0, 3 and 4 are consumed.
        conv_output = conv(x, w)
        bn_output = bn(conv_output, scale, b, mean, variance)
        item0 = tuple_getitem(bn_output, 0)
        item1 = tuple_getitem(bn_output, 3)
        item2 = tuple_getitem(bn_output, 4)
        output = make_tuple(item0, item1, item2)
        res = tuple_getitem(output, 0)
        return res

    @fns
    def after(x, w, scale, b, mean, variance):
        # Expected graph after the fusion pass: ConvBN1 -> FusedBN2 -> FusedBN3,
        # re-packed to mimic the original bn tuple output order.
        conv_bn1_output = conv_bn1(x, w)
        conv_item0 = tuple_getitem(conv_bn1_output, 0)
        conv_item1 = tuple_getitem(conv_bn1_output, 1)
        conv_item2 = tuple_getitem(conv_bn1_output, 2)
        bn2_output = fused_bn2(conv_item2, conv_item1, mean, variance)
        bn2_item0 = tuple_getitem(bn2_output, 0)
        bn2_item1 = tuple_getitem(bn2_output, 1)
        bn2_item2 = tuple_getitem(bn2_output, 2)
        bn3_output = fused_bn3(conv_item0, conv_item2, bn2_item0, scale, b)
        output = make_tuple(bn3_output, bn2_item1, bn2_item2, conv_item2, bn2_item0)
        item0 = tuple_getitem(output, 0)
        item1 = tuple_getitem(output, 3)
        item2 = tuple_getitem(output, 4)
        new_output = make_tuple(item0, item1, item2)
        return make_tuple(tuple_getitem(new_output, 0))

    return fns[tag]


def test_conv_bn_add_relu_fusion(tag):
    """ test_conv_bn_add_relu_fusion """
    # Returns the graph builder registered under `tag` ("before" or "after").
    fns = FnDict()

    @fns
    def before(x, w, scale, b, mean, variance, y):
        # Conv2D -> FusedBatchNorm; bn output 0 is added to y and passed through relu.
        conv_output = conv(x, w)
        bn_output = bn(conv_output, scale, b, mean, variance)
        item0 = tuple_getitem(bn_output, 0)
        s = add(item0, y)
        res = relu(s)
        return res

    @fns
    def after(x, w, scale, b, mean, variance, y):
        # Expected graph after the fusion pass: ConvBN1 followed by a single
        # BN2AddRelu node that also consumes the add operand y.
        conv_bn1_output = conv_bn1(x, w)
        conv_item0 = tuple_getitem(conv_bn1_output, 0)
        conv_item1 = tuple_getitem(conv_bn1_output, 1)
        conv_item2 = tuple_getitem(conv_bn1_output, 2)
        bn2_add_relu_output = bn2_add_relu(conv_item0, conv_item1, conv_item2, y, scale, b, mean, variance)
        bn2_add_relu_item0 = tuple_getitem(bn2_add_relu_output, 0)
        res = make_tuple(bn2_add_relu_item0)
        return res

    return fns[tag]


def test_conv_bn_relu_fusion(tag):
    """ test_conv_bn_relu_fusion """
    # Returns the graph builder registered under `tag` ("before" or "after").
    fns = FnDict()

    @fns
    def before(x, w, scale, b, mean, variance):
        # Conv2D -> FusedBatchNorm; relu is applied to bn output 0,
        # and bn outputs 3 and 4 are also consumed.
        conv_output = conv(x, w)
        bn_output = bn(conv_output, scale, b, mean, variance)
        item0 = tuple_getitem(bn_output, 0)
        item1 = tuple_getitem(bn_output, 3)
        item2 = tuple_getitem(bn_output, 4)
        output = make_tuple(relu(item0), item1, item2)
        res = tuple_getitem(output, 0)
        return res

    @fns
    def after(x, w, scale, b, mean, variance):
        # Expected graph after the fusion pass: ConvBN1 followed by BN2Relu;
        # the make_tuple mixes bn2_relu outputs with conv_bn1 output 2.
        conv_bn1_output = conv_bn1(x, w)
        conv_item0 = tuple_getitem(conv_bn1_output, 0)
        conv_item1 = tuple_getitem(conv_bn1_output, 1)
        conv_item2 = tuple_getitem(conv_bn1_output, 2)
        bn2_relu_output = bn2_relu(conv_item0, conv_item1, conv_item2, scale, b, mean, variance)
        bn2_relu_item0 = tuple_getitem(bn2_relu_output, 0)
        bn2_relu_item1 = tuple_getitem(bn2_relu_output, 1)
        bn2_relu_item2 = tuple_getitem(bn2_relu_output, 2)
        bn2_relu_item3 = tuple_getitem(bn2_relu_output, 3)
        new_make_tuple = make_tuple(bn2_relu_item0, bn2_relu_item1, bn2_relu_item2, conv_item2, bn2_relu_item3)
        item1 = tuple_getitem(new_make_tuple, 3)
        item2 = tuple_getitem(new_make_tuple, 4)
        output = make_tuple(bn2_relu_item0, item1, item2)
        return make_tuple(tuple_getitem(output, 0))

    return fns[tag]


def test_bn_split(tag): def test_bn_split(tag):
""" test_split_bn_fusion """ """ test_split_bn_fusion """
fns = FnDict() fns = FnDict()


Loading…
Cancel
Save