Browse Source

resolve encoder model parser problems

tags/v1.2.0-rc1
cjh9368 5 years ago
parent
commit
2532f405de
1 changed files with 9 additions and 7 deletions
  1. +9
    -7
      mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc

+ 9
- 7
mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc View File

@@ -47,14 +47,14 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) {
}

STATUS SingleSwitchPass::DoubleSwitchOutput() {
origin_switch_output_tensor_indices_ = switch_node_->outputIndex;
if (origin_switch_output_tensor_indices_.size() != first_partial_node_->inputIndex.size()) {
auto cur_switch_output_tensor_indices = switch_node_->outputIndex;
if (cur_switch_output_tensor_indices.size() != first_partial_node_->inputIndex.size()) {
MS_LOG(ERROR) << "switch node: " << switch_node_->name << " input or output number is not right.";
return RET_ERROR;
}
MS_ASSERT(origin_switch_output_tensor_indices_.size() == first_partial_node_->inputIndex.szie());
for (size_t i = 0; i < origin_switch_output_tensor_indices_.size(); i++) {
auto &switch_out_tensor = graph_->allTensors.at(origin_switch_output_tensor_indices_[i]);
for (size_t i = 0; i < cur_switch_output_tensor_indices.size(); i++) {
auto &switch_out_tensor = graph_->allTensors.at(cur_switch_output_tensor_indices[i]);
const auto &cond_partial_input_tensor = graph_->allTensors.at(first_partial_node_->inputIndex[i]);
switch_out_tensor->dataType = cond_partial_input_tensor->dataType;
auto tensor = NewTensor(switch_out_tensor);
@@ -293,7 +293,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() {
return ret;
}

switch_node_->inputIndex.erase(switch_node_->inputIndex.begin(), switch_node_->inputIndex.begin() + 3);
switch_node_->inputIndex.erase(switch_node_->inputIndex.begin(), switch_node_->inputIndex.begin() + 2);
MS_ASSERT(switch_node_->outputIndex.size() % 2 == 0);
first_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(),
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2);
@@ -328,8 +328,10 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() {
}

if (second_graph_nodes_.empty()) {
merge_node->inputIndex.assign(switch_node_->outputIndex.begin(),
switch_node_->outputIndex.begin() + second_partial_node_->outputIndex.size());
merge_node->inputIndex.insert(merge_node->inputIndex.end(),
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2,
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2 +
second_partial_node_->outputIndex.size());
second_subgraph_index_ = -1;
IsolateUselessNode(second_partial_node_, graph_);
} else {


Loading…
Cancel
Save