|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <map> |
|
|
|
#include <set> |
|
|
|
#include <algorithm> |
|
|
|
#include "tools/converter/legacy_optimizer/graph/switch_pass.h" |
|
|
|
#include "src/common/log_adapter.h" |
|
|
|
@@ -47,7 +48,10 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) { |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::DoubleSwitchOutput() { |
|
|
|
origin_switch_output_tensor_indices_ = switch_node_->outputIndex; |
|
|
|
MS_ASSERT(origin_switch_output_tensor_indices_.size() == cond_partial_node_->inputIndex.szie()); |
|
|
|
if (origin_switch_output_tensor_indices_.size() != cond_partial_node_->inputIndex.size()) { |
|
|
|
MS_LOG(ERROR) << "switch node: " << switch_node_->name << " input or output number is not right."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < origin_switch_output_tensor_indices_.size(); i++) { |
|
|
|
auto &switch_out_tensor = graph_->allTensors.at(origin_switch_output_tensor_indices_[i]); |
|
|
|
const auto &cond_partial_input_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex[i]); |
|
|
|
@@ -60,7 +64,7 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void SingleSwitchPass::DoubleIdx(uint32_t *idx) { |
|
|
|
void SingleSwitchPass::UpdateSwitchOutputIndices(uint32_t *idx) { |
|
|
|
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx); |
|
|
|
if (iter != switch_node_->outputIndex.end()) { |
|
|
|
int pos = iter - switch_node_->outputIndex.begin(); |
|
|
|
@@ -69,25 +73,21 @@ void SingleSwitchPass::DoubleIdx(uint32_t *idx) { |
|
|
|
} |
|
|
|
|
|
|
|
// Rewrites every reference to a switch output tensor (node inputs, subgraph
// output lists and the graph output list) via UpdateSwitchOutputIndices, and
// collects the nodes in this subgraph that consume switch outputs.
// NOTE(review): the stale DoubleIdx(&idx) calls — the pre-rename name of
// UpdateSwitchOutputIndices left behind by a botched merge — are removed;
// each index must be remapped exactly once.
STATUS SingleSwitchPass::UpdateSwitchUser() {
  std::vector<CNodeT *> switch_users;
  for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) {
    auto &node = graph_->nodes.at(node_idx);
    for (auto &idx : node->inputIndex) {
      if (IsContain(switch_node_->outputIndex, idx)) {
        switch_users.push_back(node.get());
      }
      UpdateSwitchOutputIndices(&idx);
    }
  }
  // update graph switch user
  for (auto &subgraph : graph_->subGraph) {
    for (auto &idx : subgraph->outputIndices) {
      UpdateSwitchOutputIndices(&idx);
    }
  }
  for (auto &idx : graph_->outputIndex) {
    UpdateSwitchOutputIndices(&idx);
  }
  return RET_OK;
}
|
|
|
@@ -104,20 +104,71 @@ bool SingleSwitchPass::IsLoop() { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor) { |
|
|
|
std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor, |
|
|
|
bool with_data) { |
|
|
|
auto out_tensor = std::make_unique<schema::TensorT>(); |
|
|
|
out_tensor->nodeType = in_tensor->nodeType; |
|
|
|
out_tensor->dims = in_tensor->dims; |
|
|
|
out_tensor->dataType = in_tensor->dataType; |
|
|
|
out_tensor->data = in_tensor->data; |
|
|
|
out_tensor->format = in_tensor->format; |
|
|
|
if (with_data) { |
|
|
|
out_tensor->data = in_tensor->data; |
|
|
|
} |
|
|
|
return out_tensor; |
|
|
|
} |
|
|
|
|
|
|
|
// Collects, for each output tensor of the body subgraph, the partial-input
// slot index of the node that produces it. These slots are the loop-carried
// (variable) inputs; everything else is treated as constant by the caller.
STATUS SingleSwitchPass::BodyGraphVariableInput(std::vector<size_t> *variable_input) {
  auto &body_subgraph = graph_->subGraph.at(body_subgraph_index_);
  auto output_indices = body_subgraph->outputIndices;
  for (auto &output_idx : output_indices) {
    for (auto &producer : body_graph_nodes_) {
      // Skip nodes that do not produce this subgraph output.
      if (producer == nullptr || !IsContain(producer->outputIndex, output_idx)) {
        continue;
      }
      int slot = GetSubgraphOutputTensorIndex(body_subgraph, producer);
      if (slot == -1) {
        MS_LOG(ERROR) << "get input index failed.";
        return RET_ERROR;
      }
      variable_input->push_back(slot);
    }
  }
  return RET_OK;
}
|
|
|
|
|
|
|
STATUS SingleSwitchPass::InsertMerge() { |
|
|
|
int ret = RET_OK; |
|
|
|
// update body graph output |
|
|
|
auto &body_fg = graph_->subGraph.at(body_subgraph_index_); |
|
|
|
body_fg->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(), |
|
|
|
body_to_cond_partial_node_->inputIndex.end()); |
|
|
|
|
|
|
|
// remove body_to_cond_partial_node_ from body_graph_nodes_ |
|
|
|
for (auto it = body_graph_nodes_.begin(); it != body_graph_nodes_.end();) { |
|
|
|
if (*it == body_to_cond_partial_node_) { |
|
|
|
it = body_graph_nodes_.erase(it); |
|
|
|
} else { |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// isolate body_to_cond_partial_node_ |
|
|
|
IsolateUselessNode(body_to_cond_partial_node_, graph_); |
|
|
|
|
|
|
|
std::vector<size_t> variable_input{}; |
|
|
|
int ret = BodyGraphVariableInput(&variable_input); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "get body graph variable input failed, ret: " << ret; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> const_input{}; |
|
|
|
for (size_t i = 0; i < body_partial_node_->inputIndex.size(); i++) { |
|
|
|
if (IsContain(variable_input, i)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const_input.push_back(i); |
|
|
|
} |
|
|
|
|
|
|
|
auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); |
|
|
|
MS_ASSERT(merge_node != nullptr); |
|
|
|
auto primitiveT = std::unique_ptr<PrimitiveT>(new (std::nothrow) PrimitiveT); |
|
|
|
MS_ASSERT(primitiveT != nullptr); |
|
|
|
merge_node->primitive = std::move(primitiveT); |
|
|
|
@@ -129,8 +180,6 @@ STATUS SingleSwitchPass::InsertMerge() { |
|
|
|
MS_ASSERT(merge_param != nullptr); |
|
|
|
merge_node->primitive->value.value = merge_param.release(); |
|
|
|
|
|
|
|
merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end()); |
|
|
|
|
|
|
|
// merge node output is same as switch |
|
|
|
for (auto &out_index : origin_switch_output_tensor_indices_) { |
|
|
|
auto &switch_out_tensor = graph_->allTensors.at(out_index); |
|
|
|
@@ -139,12 +188,30 @@ STATUS SingleSwitchPass::InsertMerge() { |
|
|
|
merge_node->outputIndex.push_back(graph_->allTensors.size() - 1); |
|
|
|
} |
|
|
|
|
|
|
|
// double merge inputs to contain the outputs of body node |
|
|
|
for (auto &index : cond_partial_node_->inputIndex) { |
|
|
|
auto &in_tensor = graph_->allTensors.at(index); |
|
|
|
auto tensor = NewTensor(in_tensor); |
|
|
|
graph_->allTensors.push_back(std::move(tensor)); |
|
|
|
merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); |
|
|
|
merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end()); |
|
|
|
|
|
|
|
std::set<uint32_t> input_set{}; |
|
|
|
for (auto &iter : merge_node->inputIndex) { |
|
|
|
if (input_set.find(iter) != input_set.end()) { |
|
|
|
auto &in_tensor = graph_->allTensors.at(iter); |
|
|
|
auto tensor = NewTensor(in_tensor, true); |
|
|
|
graph_->allTensors.push_back(std::move(tensor)); |
|
|
|
iter = graph_->allTensors.size() - 1; |
|
|
|
} |
|
|
|
input_set.insert(iter); |
|
|
|
} |
|
|
|
|
|
|
|
// double merge inputs to contain the outputs of body node |
|
|
|
auto old_merge_input = merge_node->inputIndex; |
|
|
|
for (size_t i = 0; i < old_merge_input.size(); i++) { |
|
|
|
auto &in_tensor = graph_->allTensors.at(old_merge_input[i]); |
|
|
|
if (IsContain(const_input, i)) { |
|
|
|
merge_node->inputIndex.push_back(old_merge_input[i]); |
|
|
|
} else { |
|
|
|
auto tensor = NewTensor(in_tensor); |
|
|
|
graph_->allTensors.push_back(std::move(tensor)); |
|
|
|
merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// insert merge node before the cond graph |
|
|
|
@@ -182,46 +249,12 @@ STATUS SingleSwitchPass::InsertMerge() { |
|
|
|
graph_->nodes.push_back(std::move(merge_node)); |
|
|
|
graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); |
|
|
|
|
|
|
|
// update bodu graph output |
|
|
|
graph_->subGraph.at(body_subgraph_index_) |
|
|
|
->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(), |
|
|
|
body_to_cond_partial_node_->inputIndex.end()); |
|
|
|
|
|
|
|
// erase body_to_cond_partial_node_ |
|
|
|
RemoveUselessNode(body_to_cond_partial_node_, graph_); |
|
|
|
return ret; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void SingleSwitchPass::RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) { |
|
|
|
void SingleSwitchPass::IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) { |
|
|
|
partial_node->inputIndex.clear(); |
|
|
|
partial_node->outputIndex.clear(); |
|
|
|
|
|
|
|
int pos = -1; |
|
|
|
for (size_t i = 0; i < graph->nodes.size(); ++i) { |
|
|
|
if (graph->nodes.at(i).get() == partial_node) { |
|
|
|
pos = i; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (pos == -1) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
graph->nodes.erase(graph->nodes.begin() + pos); |
|
|
|
|
|
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
|
for (auto it = subgraph->nodeIndices.begin(); it != subgraph->nodeIndices.end();) { |
|
|
|
if (*it == static_cast<uint32_t>(pos)) { |
|
|
|
it = subgraph->nodeIndices.erase(it); |
|
|
|
} else { |
|
|
|
if (*it > static_cast<uint32_t>(pos)) { |
|
|
|
(*it)--; |
|
|
|
} |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
size_t SingleSwitchPass::InitThisGraphIndex() { |
|
|
|
@@ -265,12 +298,10 @@ STATUS SingleSwitchPass::Init() { |
|
|
|
for (auto &out_index : iter->get()->outputIndex) { |
|
|
|
if (out_index == switch_node_->inputIndex[kSwitchCondIndex]) { |
|
|
|
cond_partial_node_ = iter->get(); |
|
|
|
cond_node_index_ = iter - graph_->nodes.begin(); |
|
|
|
find_cond_node = true; |
|
|
|
} |
|
|
|
if (out_index == switch_node_->inputIndex[kSwitchBodyIndex]) { |
|
|
|
body_partial_node_ = iter->get(); |
|
|
|
body_node_index_ = iter - graph_->nodes.begin(); |
|
|
|
find_body_node = true; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -301,6 +332,41 @@ STATUS SingleSwitchPass::Init() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int SingleSwitchPass::GetSubgraphInputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, |
|
|
|
const std::unique_ptr<TensorT> &tensor) { |
|
|
|
int partial_idx = -1; |
|
|
|
if (tensor->name.find("_input_") != std::string::npos) { |
|
|
|
// get parameter input index k. subgraph name + “_input_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_input_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensor->name.find("_output_") != std::string::npos) { |
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
return partial_idx; |
|
|
|
} |
|
|
|
|
|
|
|
int SingleSwitchPass::GetSubgraphOutputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, CNodeT *node) { |
|
|
|
int partial_idx = -1; |
|
|
|
if (node->name == "LogicalAnd") { |
|
|
|
partial_idx = 0; |
|
|
|
} else { |
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_"); |
|
|
|
auto pos2 = node->name.find('_', pos); |
|
|
|
auto idx_str = node->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
return partial_idx; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, |
|
|
|
const std::vector<schema::CNodeT *> &subgraph_nodes) { |
|
|
|
if (partial_node == nullptr || subgraph_nodes.empty()) { |
|
|
|
@@ -315,27 +381,11 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem |
|
|
|
std::vector<std::pair<int, int>> tmp_inputs_order{}; |
|
|
|
for (unsigned int &subgraph_input : subgraph_inputs) { |
|
|
|
auto &tensor = graph_->allTensors.at(subgraph_input); |
|
|
|
if (tensor->name.size() < subgraph->name.size() + 8) { |
|
|
|
MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right."; |
|
|
|
int partial_idx = GetSubgraphInputTensorIndex(subgraph, tensor); |
|
|
|
if (partial_idx == -1) { |
|
|
|
MS_LOG(ERROR) << "get input index failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
int partial_idx = -1; |
|
|
|
if (tensor->name.find("_input_") != std::string::npos) { |
|
|
|
// get parameter input index k. subgraph name + “_input_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_input_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensor->name.find("_output_") != std::string::npos) { |
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_"); |
|
|
|
auto pos2 = tensor->name.find('_', pos); |
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
} |
|
|
|
|
|
|
|
subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]}); |
|
|
|
tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]); |
|
|
|
} |
|
|
|
@@ -374,15 +424,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche |
|
|
|
for (unsigned int &subgraph_output : subgraph_outputs) { |
|
|
|
for (auto &node : subgraph_nodes) { |
|
|
|
if (IsContain(node->outputIndex, subgraph_output)) { |
|
|
|
int partial_idx = -1; |
|
|
|
if (node->name == "LogicalAnd") { |
|
|
|
partial_idx = 0; |
|
|
|
} else { |
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k" |
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_"); |
|
|
|
auto pos2 = node->name.find('_', pos); |
|
|
|
auto idx_str = node->name.substr(pos - 1, pos2); |
|
|
|
partial_idx = std::stoi(idx_str); |
|
|
|
int partial_idx = GetSubgraphOutputTensorIndex(subgraph, node); |
|
|
|
if (partial_idx == -1) { |
|
|
|
MS_LOG(ERROR) << "get input index failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]}); |
|
|
|
tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]); |
|
|
|
@@ -473,7 +518,6 @@ STATUS SingleSwitchPass::Run() { |
|
|
|
MS_LOG(ERROR) << "ConcatBodySubgraphInputAndOutput failed, ret: " << ret; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
} // namespace mindspore::lite |