|
|
|
@@ -251,9 +251,34 @@ STATUS SingleSwitchPass::InsertMerge() { |
|
|
|
second_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(), |
|
|
|
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2); |
|
|
|
|
|
|
|
// skip tensors that are not an input of any node, to avoid connecting the body partial to the merge input cnode |
|
|
|
std::vector<uint32_t> skip_input_tensors; |
|
|
|
for (auto input : const_input) { |
|
|
|
auto real_input = graph_->subGraph.at(second_subgraph_index_)->inputIndices.at(input); |
|
|
|
bool skip = true; |
|
|
|
for (auto &node : second_graph_nodes_) { |
|
|
|
if (IsContain(node->inputIndex, real_input)) { |
|
|
|
skip = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (skip) { |
|
|
|
auto &skip_tensor = graph_->allTensors.at(real_input); |
|
|
|
int partial_idx = GetSubgraphInputTensorIndex(graph_->subGraph.at(second_subgraph_index_), skip_tensor); |
|
|
|
skip_input_tensors.emplace_back(partial_idx); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// concat body output to merge input |
|
|
|
second_partial_node_->outputIndex.assign(merge_node->inputIndex.begin() + merge_node->inputIndex.size() / 2, |
|
|
|
merge_node->inputIndex.end()); |
|
|
|
second_partial_node_->outputIndex.clear(); |
|
|
|
for (uint32_t merge_right_input = 0; merge_right_input < merge_node->inputIndex.size() / 2; merge_right_input++) { |
|
|
|
if (!IsContain(skip_input_tensors, merge_right_input)) { |
|
|
|
second_partial_node_->outputIndex.emplace_back( |
|
|
|
merge_node->inputIndex.at(merge_node->inputIndex.size() / 2 + merge_right_input)); |
|
|
|
} else { |
|
|
|
second_partial_node_->outputIndex.emplace_back(UINT32_MAX); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
graph_->nodes.push_back(std::move(merge_node)); |
|
|
|
|
|
|
|
@@ -544,6 +569,13 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche |
|
|
|
[](std::pair<int, int> iter) { return iter.second; }); |
|
|
|
subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); |
|
|
|
|
|
|
|
// filter out placeholder output indices (UINT32_MAX, i.e. -1 as uint32_t) |
|
|
|
std::vector<uint32_t> new_partial_outputs; |
|
|
|
std::copy_if(partial_outputs.begin(), partial_outputs.end(), |
|
|
|
std::inserter(new_partial_outputs, new_partial_outputs.begin()), |
|
|
|
[](uint32_t output) { return output != UINT32_MAX; }); |
|
|
|
partial_node->outputIndex = new_partial_outputs; |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|