|
|
|
@@ -60,19 +60,36 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void SingleSwitchPass::DoubleIdx(uint32_t *idx) { |
|
|
|
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx); |
|
|
|
if (iter != switch_node_->outputIndex.end()) { |
|
|
|
int pos = iter - switch_node_->outputIndex.begin(); |
|
|
|
*idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::UpdateSwitchUser() { |
|
|
|
std::vector<CNodeT *> switch_users; |
|
|
|
for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) { |
|
|
|
auto &node = graph_->nodes.at(node_idx); |
|
|
|
for (auto &idx : node->inputIndex) { |
|
|
|
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), idx); |
|
|
|
if (iter != switch_node_->outputIndex.end()) { |
|
|
|
if (IsContain(switch_node_->outputIndex, idx)) { |
|
|
|
switch_users.push_back(node.get()); |
|
|
|
int pos = iter - switch_node_->outputIndex.begin(); |
|
|
|
idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2); |
|
|
|
} |
|
|
|
DoubleIdx(&idx); |
|
|
|
} |
|
|
|
} |
|
|
|
// update graph switch user |
|
|
|
for (auto &subgraph : graph_->subGraph) { |
|
|
|
for (auto &idx : subgraph->outputIndices) { |
|
|
|
DoubleIdx(&idx); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &idx : graph_->outputIndex) { |
|
|
|
DoubleIdx(&idx); |
|
|
|
} |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -307,7 +324,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem |
|
|
|
// get parameter input index k. subgraph name + “_input_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_input_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -315,7 +332,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem |
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
|
|
|
|
|