|
|
@@ -115,6 +115,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { |
|
|
auto &node = graph->nodes.at(i); |
|
|
auto &node = graph->nodes.at(i); |
|
|
std::vector<SubGraphT *> contain_node_input_subgraphs{}; |
|
|
std::vector<SubGraphT *> contain_node_input_subgraphs{}; |
|
|
std::vector<SubGraphT *> contain_node_output_subgraphs{}; |
|
|
std::vector<SubGraphT *> contain_node_output_subgraphs{}; |
|
|
|
|
|
std::vector<SubGraphT *> contain_subgraphs{}; |
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
std::set<uint32_t> tensors_indices{}; |
|
|
std::set<uint32_t> tensors_indices{}; |
|
|
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices); |
|
|
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices); |
|
|
@@ -129,26 +130,26 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { |
|
|
contain_node_output_subgraphs.push_back(subgraph.get()); |
|
|
contain_node_output_subgraphs.push_back(subgraph.get()); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && |
|
|
|
|
|
contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) { |
|
|
|
|
|
MS_LOG(ERROR) << "not support single node index insert."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && |
|
|
|
|
|
contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) { |
|
|
|
|
|
|
|
|
std::set_intersection(contain_node_input_subgraphs.begin(), contain_node_input_subgraphs.end(), |
|
|
|
|
|
contain_node_output_subgraphs.begin(), contain_node_output_subgraphs.end(), |
|
|
|
|
|
inserter(contain_subgraphs, contain_subgraphs.begin())); |
|
|
|
|
|
if (contain_subgraphs.size() == 1) { |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
contain_node_input_subgraphs[0]->nodeIndices.push_back(i); |
|
|
|
|
|
|
|
|
contain_subgraphs[0]->nodeIndices.push_back(i); |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
if (contain_node_input_subgraphs.size() == 1) { |
|
|
|
|
|
|
|
|
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) { |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
contain_node_input_subgraphs[0]->nodeIndices.push_back(i); |
|
|
contain_node_input_subgraphs[0]->nodeIndices.push_back(i); |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
if (contain_node_output_subgraphs.size() == 1) { |
|
|
|
|
|
|
|
|
if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) { |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
contain_node_output_subgraphs[0]->nodeIndices.push_back(i); |
|
|
contain_node_output_subgraphs[0]->nodeIndices.push_back(i); |
|
|
continue; |
|
|
continue; |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name; |
|
|
|
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|