From 6fe923032e88f4f8a78ac6b98536202833193f3f Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 9 Feb 2021 15:58:16 +0800 Subject: [PATCH] test --- .../lite/tools/anf_exporter/anf_exporter.cc | 42 ++++++++++++++++++- mindspore/lite/tools/converter/converter.cc | 1 + 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 4654158871..c810d223f5 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -16,6 +16,7 @@ #include "tools/anf_exporter/anf_exporter.h" +#include #include #include #include @@ -32,6 +33,45 @@ #include "tools/common/graph_util.h" namespace mindspore::lite { +namespace { +std::list GetOrderedCNodes(const FuncGraphPtr fg) { + auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); + auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + // Check if free variables used. + for (const auto &input : inputs) { + auto input_fg = GetValueNode(input); + if (input_fg) { + for (auto &fv : input_fg->free_variables_nodes()) { + if (fv->func_graph() == fg && fg->nodes().contains(fv)) { + vecs.push_back(fv); + } + } + } + } + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + return vecs; + }; + + std::list cnodes; + auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph); + for (const auto &node : nodes) { + auto cnode = dyn_cast(node); + if (cnode) { + cnodes.push_back(cnode); + } + } + return cnodes; +} +} // namespace + void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { bool has_make_tuple = false; std::vector inputs; @@ -220,7 +260,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr &sub_graphT) { int ret = RET_OK; - auto cnodes = func_graph->GetOrderedCnodes(); + auto cnodes = GetOrderedCNodes(func_graph); for (const auto &cnode : cnodes) { auto primitive_c = GetValueNode>(cnode->input(0)); if (primitive_c == nullptr) { diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 09c9049bd3..69e11accc5 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -178,6 +178,7 @@ int RunConverter(int argc, const char **argv) { delete fb_graph; MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status; std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; + return status; } } // namespace lite