
fix: single batch norm fission does not work when only outputs 0, 1, 2 and 4 of BatchNorm are used

tags/v0.5.0-beta
huanghui, 5 years ago
commit 36d1aadf1c
2 changed files with 3 additions and 3 deletions
  1. mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+1, -1)
  2. mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc (+2, -2)

mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+1, -1)

@@ -99,6 +99,7 @@ namespace {
 void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   MS_EXCEPTION_IF_NULL(ir_fusion_pm);
   ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
+  ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
   ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
@@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
-    ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
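
For context: the backend pass manager runs IR passes in registration order, so this change both moves SingleBatchNormFission out of the fused-batch-norm block and reorders it to run right after BatchNormBertFission in the optional IR-fusion stage. A minimal sketch of that registration pattern, using simplified hypothetical Pass/PassManager types (not the real ones under mindspore/ccsrc/pre_activate):

    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical, stripped-down stand-ins for the backend pass types.
    struct Pass {
      virtual ~Pass() = default;
      virtual std::string name() const = 0;
    };
    struct BatchNormBertFission : Pass {
      std::string name() const override { return "batch_norm_bert_fission"; }
    };
    struct SingleBatchNormFission : Pass {
      std::string name() const override { return "single_batch_norm_fission"; }
    };

    class PassManager {
     public:
      // Passes execute in the order they are added.
      void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
      const std::vector<std::shared_ptr<Pass>> &passes() const { return passes_; }
     private:
      std::vector<std::shared_ptr<Pass>> passes_;
    };

    int main() {
      PassManager ir_fusion_pm;
      // Same ordering as the hunk above: the fission pass now runs
      // immediately after BatchNormBertFission.
      ir_fusion_pm.AddPass(std::make_shared<BatchNormBertFission>());
      ir_fusion_pm.AddPass(std::make_shared<SingleBatchNormFission>());
      return 0;
    }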


mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc (+2, -2)

@@ -24,7 +24,7 @@ namespace mindspore {
 namespace opt {
 namespace {
 const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
-constexpr size_t kBatchNormRealOutputNum = 5;
+constexpr size_t kBatchNormLeastOutputNum = 1;
 constexpr size_t kBatchNormRealInputNum = 3;

 bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
@@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
     bn_outputs->push_back(output);
     output_num++;
   }
-  return output_num == kBatchNormRealOutputNum;
+  return output_num > kBatchNormLeastOutputNum;
 }

 AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
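
The behavioral fix is this return value: GetBatchNormOutputs previously reported success only when all five BatchNorm outputs had consumers, so a graph that used only outputs 0, 1, 2 and 4 was skipped by the fission pass; now any node with more than one consumed output qualifies. A self-contained sketch of the old versus new predicate (CountUsedOutputs-style counting is modeled here with a plain std::set; this is an illustration, not the MindSpore source):

    #include <cstddef>
    #include <set>

    constexpr size_t kBatchNormRealOutputNum = 5;   // old threshold
    constexpr size_t kBatchNormLeastOutputNum = 1;  // new threshold

    // used: indices of BatchNorm outputs that have at least one consumer.
    bool OldPredicate(const std::set<int> &used) {
      return used.size() == kBatchNormRealOutputNum;  // matched only when all 5 were used
    }

    bool NewPredicate(const std::set<int> &used) {
      return used.size() > kBatchNormLeastOutputNum;  // matches any 2+ used outputs
    }

    int main() {
      // The case from the commit message: outputs 0, 1, 2 and 4 are used.
      std::set<int> used{0, 1, 2, 4};
      // OldPredicate(used) == false -> fission was skipped before this commit.
      // NewPredicate(used) == true  -> fission now applies.
      return (!OldPredicate(used) && NewPredicate(used)) ? 0 : 1;
    }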

