diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 3b8edc6f3..679cc91f6 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -1824,15 +1824,15 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) +static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); - // LSTM <= LSTM - Transpose - Reshape - Transpose - if (node->op_type() == "LSTM") + // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose + if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") { if (node_reference[node->output(0)] != 1) continue; @@ -1914,12 +1914,15 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::mapoutput(0)] != 1) + continue; + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); if (node4->op_type() != "Transpose") continue; - if (node4->input(0) != node3->output(0)) + if (node4->input(0) != node->output(0)) continue; // 1 0 2 @@ -1946,6 +1949,96 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::mapmutable_node(i); + + // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose + if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") + { + if (node_reference[node->output(0)] != 1) + continue; + + if (i + 1 >= node_count) + continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + + if (node2->op_type() != "Squeeze") + continue; + + if (node2->input(0) != node->output(0)) + continue; + + std::string direction = get_node_attr_s(*node, "direction"); + if (direction == "bidirectional") + continue; + + // 1 + std::vector axes = get_node_attr_ai(*node2, "axes"); + if (axes.size() != 1) + continue; + + 
if (axes[0] != 1) + continue; + + // reduce + node2->set_op_type("noop_reducedncnn"); + + node_reference[node->output(0)] -= 1; + + blob_names.erase(node->output(0)); + if (node->output_size() > 1) + { + for (int j = 1; j < node->output_size(); j++) + { + blob_names.erase(node->output(j)); + } + } + + node->clear_output(); + node->add_output(node2->output(0)); + + reduced_node_count += 1; + i += 1; + + if (i + 1 < node_count) + { + if (node_reference[node2->output(0)] != 1) + continue; + + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); + + if (node3->op_type() != "Transpose") + continue; + + if (node3->input(0) != node->output(0)) + continue; + + // 1 0 2 + std::vector perm4 = get_node_attr_ai(*node3, "perm"); + if (perm4.size() != 3) + continue; + + if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) + continue; + + // reduce + node3->set_op_type("noop_reducedncnn"); + + node_reference[node->output(0)] -= 1; + + blob_names.erase(node->output(0)); + + node->clear_output(); + node->add_output(node3->output(0)); + + reduced_node_count += 1; + i += 1; + } + } + } + for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); @@ -1969,7 +2062,7 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::mapmutable_node(i + 1); - if (node2->op_type() != "LSTM") + if (node2->op_type() != "LSTM" && node2->op_type() != "GRU" && node2->op_type() != "RNN") continue; if (node2->input(0) != node->output(0)) @@ -2124,7 +2217,7 @@ int main(int argc, char** argv) fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_bilstm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); // reduce common const 
weight node_reference for (int i = 0; i < node_count; i++) @@ -2195,6 +2288,13 @@ int main(int argc, char** argv) node_reference[node.input(2)] -= 1; } } + else if (op == "GRU") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } else if (op == "InstanceNormalization") { node_reference[node.input(1)] -= 1; @@ -2251,6 +2351,13 @@ int main(int argc, char** argv) } } } + else if (op == "RNN") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } else if (op == "Slice") { if (node.input_size() >= 2)