Browse Source

fix: BatchNorm fission does not work when only outputs 0, 1, 2, 4 are used

tags/v0.5.0-beta
huanghui 5 years ago
parent
commit
36d1aadf1c
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  2. +2
    -2
      mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc

+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc View File

@@ -99,6 +99,7 @@ namespace {
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
@@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
}
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());


+ 2
- 2
mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc View File

@@ -24,7 +24,7 @@ namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
constexpr size_t kBatchNormRealOutputNum = 5;
constexpr size_t kBatchNormLeastOutputNum = 1;
constexpr size_t kBatchNormRealInputNum = 3;

bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
@@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
bn_outputs->push_back(output);
output_num++;
}
return output_num == kBatchNormRealOutputNum;
return output_num > kBatchNormLeastOutputNum;
}

AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {


Loading…
Cancel
Save