|
|
|
@@ -91,29 +91,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// postconvert pass |
|
|
|
{ |
|
|
|
// init old node indices |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
if (!ctx.trainModel) { |
|
|
|
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); |
|
|
|
if (batch_norm_scale_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
batch_norm_scale_pass->SetFmk(ctx.fmk); |
|
|
|
fusionOptimizer.AddPass(batch_norm_scale_pass); |
|
|
|
} |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// format transform |
|
|
|
// init old node indices |
|
|
|
@@ -173,6 +150,29 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// postconvert pass |
|
|
|
{ |
|
|
|
// init old node indices |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
if (!ctx.trainModel) { |
|
|
|
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); |
|
|
|
if (batch_norm_scale_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
batch_norm_scale_pass->SetFmk(ctx.fmk); |
|
|
|
fusionOptimizer.AddPass(batch_norm_scale_pass); |
|
|
|
} |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// init old node indices |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
|