Browse Source

!11870 [MSLITE] adjust bn to scale pass

From: @zhengjun10
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
153b191c87
2 changed files with 25 additions and 26 deletions
  1. +23
    -23
      mindspore/lite/tools/converter/graphdef_transform.cc
  2. +2
    -3
      mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc

+ 23
- 23
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -91,29 +91,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }


// postconvert pass
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
}
}

{ {
// format transform // format transform
// init old node indices // init old node indices
@@ -173,6 +150,29 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }


// postconvert pass
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
}
}

{ {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();


+ 2
- 3
mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc View File

@@ -74,9 +74,8 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
MS_LOG(ERROR) << "new scaleParam failed"; MS_LOG(ERROR) << "new scaleParam failed";
return RET_ERROR; return RET_ERROR;
} }
int32_t axis =
(graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C;
scaleParam->axis = axis;
// after fusion bn must NHWC
scaleParam->axis = -1;
bnNode->primitive->value.value = scaleParam.release(); bnNode->primitive->value.value = scaleParam.release();
auto input0 = bnNode->inputIndex.at(0); auto input0 = bnNode->inputIndex.at(0);
bnNode->inputIndex.clear(); bnNode->inputIndex.clear();


Loading…
Cancel
Save