
!10720 [MS][LITE][Develop]open some graph def transform

From: @mengyuanli
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tag: v1.2.0-rc1
mindspore-ci-bot committed 5 years ago
commit 22a3587aa5
2 changed files with 159 additions and 172 deletions:
  1. mindspore/lite/tools/converter/graphdef_transform.cc  (+158 −170)
  2. mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc  (+1 −2)

mindspore/lite/tools/converter/graphdef_transform.cc  (+158 −170)

@@ -61,197 +61,185 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _

int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
if (ctx.fmk != converter::FmkType_TF) {
{
auto old_nodes = GetGraphNodes();
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
{
auto old_nodes = GetGraphNodes();
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}

// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}

// postconvert pass
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
// postconvert pass
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
}
}

if (ctx.fmk != converter::FmkType_TF) {
// format transform
{
// init old node indecies
auto old_nodes = GetGraphNodes();
// init old node indecies
auto old_nodes = GetGraphNodes();

Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}

// do quantization
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
// do quantization
if (ctx.fmk != converter::FmkType_TF) {
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}

// insert quantNode and deQuantNode
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
// insert quantNode and deQuantNode
if (ctx.fmk != converter::FmkType_TF) {
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
}
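
Taken together, the hunk above repeats one idiom: build an Optimizer, AddPass a chain of heap-allocated passes (often via new (std::nothrow)), then Run it over the MetaGraphT, treating RET_NO_CHANGE the same as RET_OK. As far as the hunk shows, the single outer if (ctx.fmk != converter::FmkType_TF) guard is removed and only the format-transform and quantization stages keep that check, so the generic cleanup and fusion passes also run for TF models. Below is a minimal, self-contained sketch of the pipeline idiom using hypothetical names (GraphPass, PassPipeline, Status), not the MindSpore Lite classes:

// Sketch of the pass-pipeline pattern used in the hunk (hypothetical names):
// a pipeline owns heap-allocated passes, runs them in order, and treats
// "no change" as success, mirroring the
// `status != RET_OK && status != RET_NO_CHANGE` checks above.
#include <memory>
#include <vector>

enum Status { kOk = 0, kNoChange = 1, kError = 2 };

struct Graph {};  // stand-in for schema::MetaGraphT

class GraphPass {
 public:
  virtual ~GraphPass() = default;
  virtual Status Run(Graph *graph) = 0;
};

class PassPipeline {
 public:
  // Takes ownership of the pass; a null pass (e.g. a new (std::nothrow) failure) is skipped.
  void AddPass(GraphPass *pass) {
    if (pass != nullptr) {
      passes_.emplace_back(pass);
    }
  }

  // Runs every pass in order; a pass that made no change is not an error.
  Status Run(Graph *graph) {
    for (auto &pass : passes_) {
      Status status = pass->Run(graph);
      if (status != kOk && status != kNoChange) {
        return status;  // stop at the first real failure
      }
    }
    return kOk;
  }

 private:
  std::vector<std::unique_ptr<GraphPass>> passes_;
};

A caller would then mirror the blocks in the hunk: build one pipeline per stage, AddPass that stage's passes, and bail out of Transform as soon as Run reports a real failure.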



mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc  (+1 −2)

@@ -52,13 +52,12 @@ STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
 while (!op_queue.empty()) {
   auto &node = op_queue.front();
   auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get()));
+  sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
   for (auto post_node_idx : post_node_idxes) {
     if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
       auto &post_node = old_nodes.at(post_node_idx);
       // check if post_node is non-depended
       if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
-        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(),
-                                   post_node->outputIndex.end());
         op_queue.push(std::move(post_node));
       }
     }
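
The topological_sort_pass.cc change shifts where tensors are marked as sinked: the added line records the dequeued node's own outputs as soon as that node is visited, and the removed lines no longer record a successor's outputs at the moment it is pushed (they are recorded when that successor is visited in turn). Below is a minimal, self-contained sketch of this Kahn-style scheduling idea with hypothetical types (Node, plain tensor indices), not the MindSpore Lite pass:

// Kahn-style topological ordering sketch (hypothetical types): a node's outputs
// are marked "sinked" when the node itself is processed, and a successor is
// enqueued once all of its input tensors are sinked.
#include <cstddef>
#include <queue>
#include <set>
#include <vector>

struct Node {
  std::vector<size_t> input_tensors;
  std::vector<size_t> output_tensors;
  std::vector<size_t> successors;  // indices into the node table
};

static bool AllInputsSinked(const Node &node, const std::set<size_t> &sinked) {
  for (size_t t : node.input_tensors) {
    if (sinked.count(t) == 0) {
      return false;
    }
  }
  return true;
}

// Returns node indices in a topological order. `sinked` initially holds the
// tensors that are available up front (graph inputs and constants).
std::vector<size_t> TopologicalOrder(const std::vector<Node> &nodes, std::set<size_t> sinked) {
  std::vector<size_t> order;
  std::vector<bool> enqueued(nodes.size(), false);
  std::queue<size_t> ready;
  // Seed with nodes whose inputs are already available.
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (AllInputsSinked(nodes[i], sinked)) {
      enqueued[i] = true;
      ready.push(i);
    }
  }
  while (!ready.empty()) {
    size_t idx = ready.front();
    ready.pop();
    order.push_back(idx);
    // Mirrors the added line in the hunk: sink this node's outputs as soon as it is processed.
    for (size_t t : nodes[idx].output_tensors) {
      sinked.insert(t);
    }
    for (size_t succ : nodes[idx].successors) {
      if (!enqueued[succ] && AllInputsSinked(nodes[succ], sinked)) {
        enqueued[succ] = true;
        ready.push(succ);
      }
    }
  }
  return order;
}

The invariant matches the pass: a node is enqueued exactly once, at the point where every tensor it consumes has been produced by an already-visited node or was available at the start.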

