diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 757b0c988d..7c5cb813b9 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -61,197 +61,185 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; - if (ctx.fmk != converter::FmkType_TF) { - { - auto old_nodes = GetGraphNodes(); - Optimizer unusedOpRemoveOptimizer; - unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); - if (!ctx.trainModel) { - unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); - } - unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); - unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); - status = unusedOpRemoveOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; - return status; - } + { + auto old_nodes = GetGraphNodes(); + Optimizer unusedOpRemoveOptimizer; + unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + if (!ctx.trainModel) { + unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); } - // topological sorting - { - Optimizer topologicalOptimizer; - topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - status = topologicalOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; - } + unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); + status = unusedOpRemoveOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + return status; } + } - // generate and infer quant parameters - { - Optimizer inferQuantParamPass; - inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); - inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); - status = inferQuantParamPass.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; - } + // generate and infer quant parameters + { + Optimizer inferQuantParamPass; + inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); + inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); + status = inferQuantParamPass.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; } + } - // postconvert pass - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer fusionOptimizer; - if (!ctx.trainModel) { - auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); - if (batch_norm_scale_pass == nullptr) { - MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; - return RET_ERROR; - } - batch_norm_scale_pass->SetFmk(ctx.fmk); - fusionOptimizer.AddPass(batch_norm_scale_pass); - } - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); - status = fusionOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; - return status; + // postconvert pass + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer fusionOptimizer; + if (!ctx.trainModel) { + auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); + if (batch_norm_scale_pass == nullptr) { + MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; + return RET_ERROR; } + batch_norm_scale_pass->SetFmk(ctx.fmk); + fusionOptimizer.AddPass(batch_norm_scale_pass); } + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; + return status; + } + } + if (ctx.fmk != converter::FmkType_TF) { // format transform - { - // init old node indecies - auto old_nodes = GetGraphNodes(); + // init old node indecies + auto old_nodes = GetGraphNodes(); - Optimizer formatTransOptimizer; - auto formatTransPass = new (std::nothrow) FormatTransPass(); - if (formatTransPass == nullptr) { - MS_LOG(ERROR) << "new formatTransPass failed"; - return RET_MEMORY_FAILED; - } - formatTransPass->SetQuantType(ctx.quantType); - formatTransPass->SetFmk(ctx.fmk); - formatTransOptimizer.AddPass(formatTransPass); - formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); - status = formatTransOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; - return status; - } + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + formatTransPass->SetQuantType(ctx.quantType); + formatTransPass->SetFmk(ctx.fmk); + formatTransOptimizer.AddPass(formatTransPass); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; } + } - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer formatTransOptimizer; - auto formatTransPass = new (std::nothrow) FormatTransPass(); - if (formatTransPass == nullptr) { - MS_LOG(ERROR) << "new formatTransPass failed"; - return RET_MEMORY_FAILED; - } - formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); - formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); - formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); - formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } + } + + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { + formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - status = formatTransOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; - return status; - } } - - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer formatTransOptimizer; - auto formatTransPass = new (std::nothrow) FormatTransPass(); - if (formatTransPass == nullptr) { - MS_LOG(ERROR) << "new formatTransPass failed"; - return RET_MEMORY_FAILED; - } - if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { - formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); - formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - } - status = formatTransOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; - return status; - } + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; } + } - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer fusionOptimizer; - fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - status = fusionOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; - return status; - } + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + return status; } + } - // do quantization - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer tensorQuantOptimizer; - tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); - tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); - tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - status = tensorQuantOptimizer.Run(graphDefT); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoQuantize failed!"; - return status; - } + // do quantization + if (ctx.fmk != converter::FmkType_TF) { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer tensorQuantOptimizer; + tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = tensorQuantOptimizer.Run(graphDefT); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoQuantize failed!"; + return status; } + } - // insert quantNode and deQuantNode - { - // init old node indecies - auto old_nodes = GetGraphNodes(); - Optimizer quantNodeOptimizer; - auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); - if (dTypeTransPass == nullptr) { - MS_LOG(ERROR) << "new dTypeTransPass failed"; - return RET_MEMORY_FAILED; - } - dTypeTransPass->SetInputDataDType(ctx.inputDataType); - dTypeTransPass->SetOutputDataDType(ctx.outputDataType); - quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); - status = quantNodeOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; - return status; - } - auto old_nodes2 = GetGraphNodes(); - quantNodeOptimizer.AddPass(dTypeTransPass); - quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); - quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); - quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); - status = quantNodeOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; - return status; - } + // insert quantNode and deQuantNode + if (ctx.fmk != converter::FmkType_TF) { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer quantNodeOptimizer; + auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); + if (dTypeTransPass == nullptr) { + MS_LOG(ERROR) << "new dTypeTransPass failed"; + return RET_MEMORY_FAILED; + } + dTypeTransPass->SetInputDataDType(ctx.inputDataType); + dTypeTransPass->SetOutputDataDType(ctx.outputDataType); + quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); + status = quantNodeOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + return status; + } + auto old_nodes2 = GetGraphNodes(); + quantNodeOptimizer.AddPass(dTypeTransPass); + quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); + status = quantNodeOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + return status; } } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc index 39b8096b38..bfb129eee9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc @@ -52,13 +52,12 @@ STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { while (!op_queue.empty()) { auto &node = op_queue.front(); auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get())); + sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end()); for (auto post_node_idx : post_node_idxes) { if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) { auto &post_node = old_nodes.at(post_node_idx); // check if post_node is non-depended if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) { - sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(), - post_node->outputIndex.end()); op_queue.push(std::move(post_node)); } }