|
|
@@ -16,6 +16,7 @@ |
|
|
|
|
|
|
|
|
#include "tools/anf_exporter/anf_exporter.h" |
|
|
#include "tools/anf_exporter/anf_exporter.h" |
|
|
|
|
|
|
|
|
|
|
|
#include <list> |
|
|
#include <memory> |
|
|
#include <memory> |
|
|
#include <string> |
|
|
#include <string> |
|
|
#include <utility> |
|
|
#include <utility> |
|
|
@@ -32,6 +33,45 @@ |
|
|
#include "tools/common/graph_util.h" |
|
|
#include "tools/common/graph_util.h" |
|
|
|
|
|
|
|
|
namespace mindspore::lite { |
|
|
namespace mindspore::lite { |
|
|
|
|
|
namespace { |
|
|
|
|
|
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) { |
|
|
|
|
|
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); |
|
|
|
|
|
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> { |
|
|
|
|
|
std::vector<AnfNodePtr> vecs; |
|
|
|
|
|
if (node == nullptr) { |
|
|
|
|
|
return vecs; |
|
|
|
|
|
} |
|
|
|
|
|
if (node->isa<CNode>()) { |
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
|
|
// Check if free variables used. |
|
|
|
|
|
for (const auto &input : inputs) { |
|
|
|
|
|
auto input_fg = GetValueNode<FuncGraphPtr>(input); |
|
|
|
|
|
if (input_fg) { |
|
|
|
|
|
for (auto &fv : input_fg->free_variables_nodes()) { |
|
|
|
|
|
if (fv->func_graph() == fg && fg->nodes().contains(fv)) { |
|
|
|
|
|
vecs.push_back(fv); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); |
|
|
|
|
|
} |
|
|
|
|
|
return vecs; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
std::list<CNodePtr> cnodes; |
|
|
|
|
|
auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph); |
|
|
|
|
|
for (const auto &node : nodes) { |
|
|
|
|
|
auto cnode = dyn_cast<CNode>(node); |
|
|
|
|
|
if (cnode) { |
|
|
|
|
|
cnodes.push_back(cnode); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return cnodes; |
|
|
|
|
|
} |
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
|
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { |
|
|
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { |
|
|
bool has_make_tuple = false; |
|
|
bool has_make_tuple = false; |
|
|
std::vector<AnfNodePtr> inputs; |
|
|
std::vector<AnfNodePtr> inputs; |
|
|
@@ -220,7 +260,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc |
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, |
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, |
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT) { |
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT) { |
|
|
int ret = RET_OK; |
|
|
int ret = RET_OK; |
|
|
auto cnodes = func_graph->GetOrderedCnodes(); |
|
|
|
|
|
|
|
|
auto cnodes = GetOrderedCNodes(func_graph); |
|
|
for (const auto &cnode : cnodes) { |
|
|
for (const auto &cnode : cnodes) { |
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); |
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); |
|
|
if (primitive_c == nullptr) { |
|
|
if (primitive_c == nullptr) { |
|
|
|