|
|
|
@@ -61,197 +61,185 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ |
|
|
|
|
|
|
|
int GraphDefTransform::Transform(const converter::Flags &ctx) { |
|
|
|
STATUS status; |
|
|
|
if (ctx.fmk != converter::FmkType_TF) { |
|
|
|
{ |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer unusedOpRemoveOptimizer; |
|
|
|
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); |
|
|
|
if (!ctx.trainModel) { |
|
|
|
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); |
|
|
|
} |
|
|
|
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); |
|
|
|
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = unusedOpRemoveOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
{ |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer unusedOpRemoveOptimizer; |
|
|
|
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); |
|
|
|
if (!ctx.trainModel) { |
|
|
|
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); |
|
|
|
} |
|
|
|
// topological sorting |
|
|
|
{ |
|
|
|
Optimizer topologicalOptimizer; |
|
|
|
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
status = topologicalOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); |
|
|
|
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = unusedOpRemoveOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// generate and infer quant parameters |
|
|
|
{ |
|
|
|
Optimizer inferQuantParamPass; |
|
|
|
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); |
|
|
|
status = inferQuantParamPass.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
// generate and infer quant parameters |
|
|
|
{ |
|
|
|
Optimizer inferQuantParamPass; |
|
|
|
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); |
|
|
|
status = inferQuantParamPass.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// postconvert pass |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
if (!ctx.trainModel) { |
|
|
|
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); |
|
|
|
if (batch_norm_scale_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
batch_norm_scale_pass->SetFmk(ctx.fmk); |
|
|
|
fusionOptimizer.AddPass(batch_norm_scale_pass); |
|
|
|
} |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; |
|
|
|
return status; |
|
|
|
// postconvert pass |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
if (!ctx.trainModel) { |
|
|
|
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); |
|
|
|
if (batch_norm_scale_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
batch_norm_scale_pass->SetFmk(ctx.fmk); |
|
|
|
fusionOptimizer.AddPass(batch_norm_scale_pass); |
|
|
|
} |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (ctx.fmk != converter::FmkType_TF) { |
|
|
|
// format transform |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
|
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
formatTransPass->SetQuantType(ctx.quantType); |
|
|
|
formatTransPass->SetFmk(ctx.fmk); |
|
|
|
formatTransOptimizer.AddPass(formatTransPass); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
formatTransPass->SetQuantType(ctx.quantType); |
|
|
|
formatTransPass->SetFmk(ctx.fmk); |
|
|
|
formatTransOptimizer.AddPass(formatTransPass); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer formatTransOptimizer; |
|
|
|
auto formatTransPass = new (std::nothrow) FormatTransPass(); |
|
|
|
if (formatTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new formatTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
} |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
status = formatTransOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer fusionOptimizer; |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = fusionOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// do quantization |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer tensorQuantOptimizer; |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = tensorQuantOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "DoQuantize failed!"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
// do quantization |
|
|
|
if (ctx.fmk != converter::FmkType_TF) { |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer tensorQuantOptimizer; |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); |
|
|
|
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
status = tensorQuantOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "DoQuantize failed!"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// insert quantNode and deQuantNode |
|
|
|
{ |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer quantNodeOptimizer; |
|
|
|
auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); |
|
|
|
if (dTypeTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new dTypeTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
dTypeTransPass->SetInputDataDType(ctx.inputDataType); |
|
|
|
dTypeTransPass->SetOutputDataDType(ctx.outputDataType); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
status = quantNodeOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
auto old_nodes2 = GetGraphNodes(); |
|
|
|
quantNodeOptimizer.AddPass(dTypeTransPass); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); |
|
|
|
status = quantNodeOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
// insert quantNode and deQuantNode |
|
|
|
if (ctx.fmk != converter::FmkType_TF) { |
|
|
|
// init old node indecies |
|
|
|
auto old_nodes = GetGraphNodes(); |
|
|
|
Optimizer quantNodeOptimizer; |
|
|
|
auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); |
|
|
|
if (dTypeTransPass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new dTypeTransPass failed"; |
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
dTypeTransPass->SetInputDataDType(ctx.inputDataType); |
|
|
|
dTypeTransPass->SetOutputDataDType(ctx.outputDataType); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); |
|
|
|
status = quantNodeOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
auto old_nodes2 = GetGraphNodes(); |
|
|
|
quantNodeOptimizer.AddPass(dTypeTransPass); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); |
|
|
|
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); |
|
|
|
status = quantNodeOptimizer.Run(graphDefT); |
|
|
|
if (status != RET_OK && status != RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|