Browse Source

convert onnx gru rnn graph

tags/20210124
nihuini 5 years ago
parent
commit
cf853bd3ce
1 changed files with 113 additions and 6 deletions
  1. +113
    -6
      tools/onnx/onnx2ncnn.cpp

+ 113
- 6
tools/onnx/onnx2ncnn.cpp View File

@@ -1824,15 +1824,15 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std:
}
}

static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// LSTM <= LSTM - Transpose - Reshape - Transpose
if (node->op_type() == "LSTM")
// LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
{
if (node_reference[node->output(0)] != 1)
continue;
@@ -1914,12 +1914,15 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o

if (i + 1 < node_count)
{
if (node_reference[node3->output(0)] != 1)
continue;

onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);

if (node4->op_type() != "Transpose")
continue;

if (node4->input(0) != node3->output(0))
if (node4->input(0) != node->output(0))
continue;

// 1 0 2
@@ -1946,6 +1949,96 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o
}
}

for (int i = 0; i < node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
{
if (node_reference[node->output(0)] != 1)
continue;

if (i + 1 >= node_count)
continue;

onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

if (node2->op_type() != "Squeeze")
continue;

if (node2->input(0) != node->output(0))
continue;

std::string direction = get_node_attr_s(*node, "direction");
if (direction == "bidirectional")
continue;

// 1
std::vector<int> axes = get_node_attr_ai(*node2, "axes");
if (axes.size() != 1)
continue;

if (axes[0] != 1)
continue;

// reduce
node2->set_op_type("noop_reducedncnn");

node_reference[node->output(0)] -= 1;

blob_names.erase(node->output(0));
if (node->output_size() > 1)
{
for (int j = 1; j < node->output_size(); j++)
{
blob_names.erase(node->output(j));
}
}

node->clear_output();
node->add_output(node2->output(0));

reduced_node_count += 1;
i += 1;

if (i + 1 < node_count)
{
if (node_reference[node2->output(0)] != 1)
continue;

onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);

if (node3->op_type() != "Transpose")
continue;

if (node3->input(0) != node->output(0))
continue;

// 1 0 2
std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
if (perm4.size() != 3)
continue;

if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
continue;

// reduce
node3->set_op_type("noop_reducedncnn");

node_reference[node->output(0)] -= 1;

blob_names.erase(node->output(0));

node->clear_output();
node->add_output(node3->output(0));

reduced_node_count += 1;
i += 1;
}
}
}

for (int i = 0; i < node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);
@@ -1969,7 +2062,7 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o

onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

if (node2->op_type() != "LSTM")
if (node2->op_type() != "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
continue;

if (node2->input(0) != node->output(0))
@@ -2124,7 +2217,7 @@ int main(int argc, char** argv)
fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_bilstm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);

// reduce common const weight node_reference
for (int i = 0; i < node_count; i++)
@@ -2195,6 +2288,13 @@ int main(int argc, char** argv)
node_reference[node.input(2)] -= 1;
}
}
else if (op == "GRU")
{
for (int j = 1; j < node.input_size(); j++)
{
node_reference[node.input(j)] -= 1;
}
}
else if (op == "InstanceNormalization")
{
node_reference[node.input(1)] -= 1;
@@ -2251,6 +2351,13 @@ int main(int argc, char** argv)
}
}
}
else if (op == "RNN")
{
for (int j = 1; j < node.input_size(); j++)
{
node_reference[node.input(j)] -= 1;
}
}
else if (op == "Slice")
{
if (node.input_size() >= 2)


Loading…
Cancel
Save