diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 24d32a424e..0a573f6bfb 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -298,10 +298,6 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap MS_LOG(ERROR) << "config should be specified"; return nullptr; } - if (old_graph->has_flag("HasTransformed")) { - old_graph->set_flag("HasTransformed", false); - return old_graph; - } auto status = RunPrecedingPass(old_graph, *config); if (status != RET_OK) { @@ -356,43 +352,44 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap return new_graph; } -STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphVector *subgraphs, - std::vector *vnodes) { - auto nodes = TopoSort(main_graph->get_return()); +void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) { + if (func_graphs_.find(func_graph) == func_graphs_.end()) { + func_graphs_.insert(func_graph); + } else { + return; + } + + auto nodes = func_graph->nodes(); for (auto &node : nodes) { - auto fg = GetValueNode(node); - if (fg) { - vnodes->push_back(utils::cast(node)); - subgraphs->push_back(fg); + if (IsValueNode(node)) { + auto new_fg = (node->cast()->value())->cast(); + GetAllFuncGraph(new_fg); + } + if (utils::isa(node)) { + auto cnode = node->cast(); + for (auto &input : cnode->inputs()) { + if (input->isa()) { + if (IsValueNode(input)) { + auto new_fg = (input->cast()->value())->cast(); + GetAllFuncGraph(new_fg); + } + } + } } } - return RET_OK; } FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { - // transform main_graph - auto new_main_graph = TransformSingleFuncGraph(main_graph, config); - if (new_main_graph == nullptr) { - MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return nullptr; - } + GetAllFuncGraph(main_graph); - // transform sub_graph - FuncGraphVector subgraphs{}; - std::vector vnodes{}; - int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes); - if (ret != RET_OK) { - MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); - return nullptr; - } - for (size_t i = 0; i < subgraphs.size(); i++) { - auto new_graph = Transform(subgraphs.at(i), config); - new_graph->set_flag("HasTransformed", true); - vnodes.at(i)->set_value(new_graph); + for (auto &fg : func_graphs_) { + auto new_main_graph = TransformSingleFuncGraph(fg, config); + if (new_main_graph == nullptr) { + MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } } - - return new_main_graph; + return main_graph; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 2970d69da3..9ddb4f4565 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -19,6 +19,7 @@ #include #include +#include #include "backend/optimizer/common/optimizer.h" #include "schema/inner/model_generated.h" #include "tools/common/storage.h" @@ -38,8 +39,6 @@ class AnfTransform { private: std::unique_ptr m_quantizer_ = nullptr; - STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphVector *subgraphs, std::vector *vnodes); - FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); static int AddFusionPass(const std::shared_ptr &optimizer, const converter::Flags *config); @@ -61,6 +60,10 @@ class AnfTransform { static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); + + void GetAllFuncGraph(const FuncGraphPtr &func_graph); + + std::set func_graphs_{}; }; } // namespace lite } // namespace mindspore