Browse Source

fix bug of switch pass

tags/v1.1.0
mengyuanli 5 years ago
parent
commit
a3df0ae160
2 changed files with 24 additions and 6 deletions
  1. +23
    -6
      mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc
  2. +1
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h

+ 23
- 6
mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc View File

@@ -60,19 +60,36 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() {
return RET_OK;
}

void SingleSwitchPass::DoubleIdx(uint32_t *idx) {
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx);
if (iter != switch_node_->outputIndex.end()) {
int pos = iter - switch_node_->outputIndex.begin();
*idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
}
}

STATUS SingleSwitchPass::UpdateSwitchUser() {
std::vector<CNodeT *> switch_users;
for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
for (auto &idx : node->inputIndex) {
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), idx);
if (iter != switch_node_->outputIndex.end()) {
if (IsContain(switch_node_->outputIndex, idx)) {
switch_users.push_back(node.get());
int pos = iter - switch_node_->outputIndex.begin();
idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
}
DoubleIdx(&idx);
}
}
// update graph switch user
for (auto &subgraph : graph_->subGraph) {
for (auto &idx : subgraph->outputIndices) {
DoubleIdx(&idx);
}
}

for (auto &idx : graph_->outputIndex) {
DoubleIdx(&idx);
}

return RET_OK;
}

@@ -307,7 +324,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
// get parameter input index k. subgraph name + “_input_" + "k"
auto pos = subgraph->name.size() + sizeof("_input_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}

@@ -315,7 +332,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}



+ 1
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h View File

@@ -56,6 +56,7 @@ class SingleSwitchPass {
const std::vector<schema::CNodeT *> &subgraph_nodes);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor);
void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void DoubleIdx(uint32_t *idx);

const size_t kSwitchCondIndex = 0;
const size_t kSwitchBodyIndex = 1;


Loading…
Cancel
Save