From 97323a3c9654cd2d41f9045226288e6828a75f38 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Sat, 30 Jan 2021 10:51:19 +0800 Subject: [PATCH] adjust bn pass loc --- .../tools/converter/graphdef_transform.cc | 46 +++++++++---------- .../graph/batchnorm_convert_scale_pass.cc | 5 +- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index eeb552b7f5..d34654f068 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -91,29 +91,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // postconvert pass - { - // init old node indices - auto old_nodes = GetGraphNodes(); - Optimizer fusionOptimizer; - if (!ctx.trainModel) { - auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); - if (batch_norm_scale_pass == nullptr) { - MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; - return RET_ERROR; - } - batch_norm_scale_pass->SetFmk(ctx.fmk); - fusionOptimizer.AddPass(batch_norm_scale_pass); - } - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); - status = fusionOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; - return status; - } - } - { // format transform // init old node indices @@ -173,6 +150,29 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } + // postconvert pass + { + // init old node indices + auto old_nodes = GetGraphNodes(); + Optimizer fusionOptimizer; + if (!ctx.trainModel) { + auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); + if (batch_norm_scale_pass == nullptr) { + MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; + return RET_ERROR; + } + batch_norm_scale_pass->SetFmk(ctx.fmk); + fusionOptimizer.AddPass(batch_norm_scale_pass); + } + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; + return status; + } + } + { + // init old node indices + auto old_nodes = GetGraphNodes(); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc index e17e3faa41..1a7ead0d78 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc @@ -74,9 +74,8 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: MS_LOG(ERROR) << "new scaleParam failed"; return RET_ERROR; } - int32_t axis = - (graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C; - scaleParam->axis = axis; + // after fusion bn must NHWC + scaleParam->axis = -1; bnNode->primitive->value.value = scaleParam.release(); auto input0 = bnNode->inputIndex.at(0); bnNode->inputIndex.clear();