|
|
|
@@ -47,14 +47,14 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) { |
|
|
|
} |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::DoubleSwitchOutput() { |
|
|
|
origin_switch_output_tensor_indices_ = switch_node_->outputIndex; |
|
|
|
if (origin_switch_output_tensor_indices_.size() != first_partial_node_->inputIndex.size()) { |
|
|
|
auto cur_switch_output_tensor_indices = switch_node_->outputIndex; |
|
|
|
if (cur_switch_output_tensor_indices.size() != first_partial_node_->inputIndex.size()) { |
|
|
|
MS_LOG(ERROR) << "switch node: " << switch_node_->name << " input or output number is not right."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
MS_ASSERT(origin_switch_output_tensor_indices_.size() == first_partial_node_->inputIndex.szie()); |
|
|
|
for (size_t i = 0; i < origin_switch_output_tensor_indices_.size(); i++) { |
|
|
|
auto &switch_out_tensor = graph_->allTensors.at(origin_switch_output_tensor_indices_[i]); |
|
|
|
for (size_t i = 0; i < cur_switch_output_tensor_indices.size(); i++) { |
|
|
|
auto &switch_out_tensor = graph_->allTensors.at(cur_switch_output_tensor_indices[i]); |
|
|
|
const auto &cond_partial_input_tensor = graph_->allTensors.at(first_partial_node_->inputIndex[i]); |
|
|
|
switch_out_tensor->dataType = cond_partial_input_tensor->dataType; |
|
|
|
auto tensor = NewTensor(switch_out_tensor); |
|
|
|
@@ -293,7 +293,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
switch_node_->inputIndex.erase(switch_node_->inputIndex.begin(), switch_node_->inputIndex.begin() + 3); |
|
|
|
switch_node_->inputIndex.erase(switch_node_->inputIndex.begin(), switch_node_->inputIndex.begin() + 2); |
|
|
|
MS_ASSERT(switch_node_->outputIndex.size() % 2 == 0); |
|
|
|
first_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(), |
|
|
|
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2); |
|
|
|
@@ -328,8 +328,10 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { |
|
|
|
} |
|
|
|
|
|
|
|
if (second_graph_nodes_.empty()) { |
|
|
|
merge_node->inputIndex.assign(switch_node_->outputIndex.begin(), |
|
|
|
switch_node_->outputIndex.begin() + second_partial_node_->outputIndex.size()); |
|
|
|
merge_node->inputIndex.insert(merge_node->inputIndex.end(), |
|
|
|
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2, |
|
|
|
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2 + |
|
|
|
second_partial_node_->outputIndex.size()); |
|
|
|
second_subgraph_index_ = -1; |
|
|
|
IsolateUselessNode(second_partial_node_, graph_); |
|
|
|
} else { |
|
|
|
|