Browse Source

fix bug of multioutput byway fusion pass

tags/v0.5.0-beta
etone-chan 5 years ago
parent
commit
23ba6291cc
2 changed files with 2 additions and 2 deletions
  1. +0
    -2
      mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  2. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc

+ 0
- 2
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc View File

@@ -66,7 +66,6 @@
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
@@ -365,7 +364,6 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm");
ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<StridedReadConvStridedWriteFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc View File

@@ -36,6 +36,8 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);


Loading…
Cancel
Save