| @@ -1824,15 +1824,15 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std: | |||
| } | |||
| } | |||
| static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count) | |||
| static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count) | |||
| { | |||
| int node_count = mutable_graph->node_size(); | |||
| for (int i = 0; i < node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| // LSTM <= LSTM - Transpose - Reshape - Transpose | |||
| if (node->op_type() == "LSTM") | |||
| // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose | |||
| if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") | |||
| { | |||
| if (node_reference[node->output(0)] != 1) | |||
| continue; | |||
| @@ -1914,12 +1914,15 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||
| if (i + 1 < node_count) | |||
| { | |||
| if (node_reference[node3->output(0)] != 1) | |||
| continue; | |||
| onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); | |||
| if (node4->op_type() != "Transpose") | |||
| continue; | |||
| if (node4->input(0) != node3->output(0)) | |||
| if (node4->input(0) != node->output(0)) | |||
| continue; | |||
| // 1 0 2 | |||
| @@ -1946,6 +1949,96 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||
| } | |||
| } | |||
| for (int i = 0; i < node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| // LSTM/GRU/RNN(uni) <= LSTM/GRU/RNN(uni) - Squeeze - Transpose: fold a directly following Squeeze(axes=[1]) and, when present, a Transpose(perm=[1,0,2]) into the recurrent node; bidirectional nodes are skipped here | |||
| if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") | |||
| { | |||
| if (node_reference[node->output(0)] != 1) | |||
| continue; | |||
| if (i + 1 >= node_count) | |||
| continue; | |||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | |||
| if (node2->op_type() != "Squeeze") | |||
| continue; | |||
| if (node2->input(0) != node->output(0)) | |||
| continue; | |||
| std::string direction = get_node_attr_s(*node, "direction"); | |||
| if (direction == "bidirectional") | |||
| continue; | |||
| // only a Squeeze whose "axes" attribute is exactly [1] is folded | |||
| std::vector<int> axes = get_node_attr_ai(*node2, "axes"); | |||
| if (axes.size() != 1) | |||
| continue; | |||
| if (axes[0] != 1) | |||
| continue; | |||
| // reduce: mark the Squeeze as noop_reducedncnn, drop every output blob name of the recurrent node, and make the Squeeze output the node's single output | |||
| node2->set_op_type("noop_reducedncnn"); | |||
| node_reference[node->output(0)] -= 1; | |||
| blob_names.erase(node->output(0)); | |||
| if (node->output_size() > 1) | |||
| { | |||
| for (int j = 1; j < node->output_size(); j++) | |||
| { | |||
| blob_names.erase(node->output(j)); | |||
| } | |||
| } | |||
| node->clear_output(); | |||
| node->add_output(node2->output(0)); | |||
| reduced_node_count += 1; | |||
| i += 1; | |||
| if (i + 1 < node_count) | |||
| { | |||
| if (node_reference[node2->output(0)] != 1) | |||
| continue; | |||
| onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); | |||
| if (node3->op_type() != "Transpose") | |||
| continue; | |||
| if (node3->input(0) != node->output(0)) | |||
| continue; | |||
| // only a Transpose whose "perm" attribute is exactly [1, 0, 2] is folded | |||
| std::vector<int> perm4 = get_node_attr_ai(*node3, "perm"); | |||
| if (perm4.size() != 3) | |||
| continue; | |||
| if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) | |||
| continue; | |||
| // reduce: mark the Transpose as noop_reducedncnn and make the Transpose output the node's single output (node->output(0) is the former Squeeze output at this point) | |||
| node3->set_op_type("noop_reducedncnn"); | |||
| node_reference[node->output(0)] -= 1; | |||
| blob_names.erase(node->output(0)); | |||
| node->clear_output(); | |||
| node->add_output(node3->output(0)); | |||
| reduced_node_count += 1; | |||
| i += 1; | |||
| } | |||
| } | |||
| } | |||
| for (int i = 0; i < node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| @@ -1969,7 +2062,7 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | |||
| if (node2->op_type() != "LSTM") | |||
| // skip unless the following node (node2) is one of the recurrent ops this pass handles: LSTM, GRU or RNN | |||
| if (node2->op_type() != "LSTM" && node2->op_type() != "GRU" && node2->op_type() != "RNN") | |||
| continue; | |||
| if (node2->input(0) != node->output(0)) | |||
| @@ -2124,7 +2217,7 @@ int main(int argc, char** argv) | |||
| fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||
| fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||
| fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||
| fuse_bilstm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||
| fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||
| // reduce common const weight node_reference | |||
| for (int i = 0; i < node_count; i++) | |||
| @@ -2195,6 +2288,13 @@ int main(int argc, char** argv) | |||
| node_reference[node.input(2)] -= 1; | |||
| } | |||
| } | |||
| else if (op == "GRU") | |||
| { | |||
| for (int j = 1; j < node.input_size(); j++) | |||
| { | |||
| node_reference[node.input(j)] -= 1; | |||
| } | |||
| } | |||
| else if (op == "InstanceNormalization") | |||
| { | |||
| node_reference[node.input(1)] -= 1; | |||
| @@ -2251,6 +2351,13 @@ int main(int argc, char** argv) | |||
| } | |||
| } | |||
| } | |||
| else if (op == "RNN") | |||
| { | |||
| for (int j = 1; j < node.input_size(); j++) | |||
| { | |||
| node_reference[node.input(j)] -= 1; | |||
| } | |||
| } | |||
| else if (op == "Slice") | |||
| { | |||
| if (node.input_size() >= 2) | |||