| @@ -1824,15 +1824,15 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std: | |||||
| } | } | ||||
| } | } | ||||
| static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count) | |||||
| static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count) | |||||
| { | { | ||||
| int node_count = mutable_graph->node_size(); | int node_count = mutable_graph->node_size(); | ||||
| for (int i = 0; i < node_count; i++) | for (int i = 0; i < node_count; i++) | ||||
| { | { | ||||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | onnx::NodeProto* node = mutable_graph->mutable_node(i); | ||||
| // LSTM <= LSTM - Transpose - Reshape - Transpose | |||||
| if (node->op_type() == "LSTM") | |||||
| // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose | |||||
| if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") | |||||
| { | { | ||||
| if (node_reference[node->output(0)] != 1) | if (node_reference[node->output(0)] != 1) | ||||
| continue; | continue; | ||||
| @@ -1914,12 +1914,15 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||||
| if (i + 1 < node_count) | if (i + 1 < node_count) | ||||
| { | { | ||||
| if (node_reference[node3->output(0)] != 1) | |||||
| continue; | |||||
| onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); | onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); | ||||
| if (node4->op_type() != "Transpose") | if (node4->op_type() != "Transpose") | ||||
| continue; | continue; | ||||
| if (node4->input(0) != node3->output(0)) | |||||
| if (node4->input(0) != node->output(0)) | |||||
| continue; | continue; | ||||
| // 1 0 2 | // 1 0 2 | ||||
| @@ -1946,6 +1949,96 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||||
| } | } | ||||
| } | } | ||||
| for (int i = 0; i < node_count; i++) | |||||
| { | |||||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||||
| // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose | |||||
| if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") | |||||
| { | |||||
| if (node_reference[node->output(0)] != 1) | |||||
| continue; | |||||
| if (i + 1 >= node_count) | |||||
| continue; | |||||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | |||||
| if (node2->op_type() != "Squeeze") | |||||
| continue; | |||||
| if (node2->input(0) != node->output(0)) | |||||
| continue; | |||||
| std::string direction = get_node_attr_s(*node, "direction"); | |||||
| if (direction == "bidirectional") | |||||
| continue; | |||||
| // 1 | |||||
| std::vector<int> axes = get_node_attr_ai(*node2, "axes"); | |||||
| if (axes.size() != 1) | |||||
| continue; | |||||
| if (axes[0] != 1) | |||||
| continue; | |||||
| // reduce | |||||
| node2->set_op_type("noop_reducedncnn"); | |||||
| node_reference[node->output(0)] -= 1; | |||||
| blob_names.erase(node->output(0)); | |||||
| if (node->output_size() > 1) | |||||
| { | |||||
| for (int j = 1; j < node->output_size(); j++) | |||||
| { | |||||
| blob_names.erase(node->output(j)); | |||||
| } | |||||
| } | |||||
| node->clear_output(); | |||||
| node->add_output(node2->output(0)); | |||||
| reduced_node_count += 1; | |||||
| i += 1; | |||||
| if (i + 1 < node_count) | |||||
| { | |||||
| if (node_reference[node2->output(0)] != 1) | |||||
| continue; | |||||
| onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); | |||||
| if (node3->op_type() != "Transpose") | |||||
| continue; | |||||
| if (node3->input(0) != node->output(0)) | |||||
| continue; | |||||
| // 1 0 2 | |||||
| std::vector<int> perm4 = get_node_attr_ai(*node3, "perm"); | |||||
| if (perm4.size() != 3) | |||||
| continue; | |||||
| if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) | |||||
| continue; | |||||
| // reduce | |||||
| node3->set_op_type("noop_reducedncnn"); | |||||
| node_reference[node->output(0)] -= 1; | |||||
| blob_names.erase(node->output(0)); | |||||
| node->clear_output(); | |||||
| node->add_output(node3->output(0)); | |||||
| reduced_node_count += 1; | |||||
| i += 1; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (int i = 0; i < node_count; i++) | for (int i = 0; i < node_count; i++) | ||||
| { | { | ||||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | onnx::NodeProto* node = mutable_graph->mutable_node(i); | ||||
| @@ -1969,7 +2062,7 @@ static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, o | |||||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | ||||
| if (node2->op_type() != "LSTM") | |||||
| if (node2->op_type() != "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") | |||||
| continue; | continue; | ||||
| if (node2->input(0) != node->output(0)) | if (node2->input(0) != node->output(0)) | ||||
| @@ -2124,7 +2217,7 @@ int main(int argc, char** argv) | |||||
| fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | ||||
| fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | ||||
| fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | ||||
| fuse_bilstm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||||
| fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); | |||||
| // reduce common const weight node_reference | // reduce common const weight node_reference | ||||
| for (int i = 0; i < node_count; i++) | for (int i = 0; i < node_count; i++) | ||||
| @@ -2195,6 +2288,13 @@ int main(int argc, char** argv) | |||||
| node_reference[node.input(2)] -= 1; | node_reference[node.input(2)] -= 1; | ||||
| } | } | ||||
| } | } | ||||
| else if (op == "GRU") | |||||
| { | |||||
| for (int j = 1; j < node.input_size(); j++) | |||||
| { | |||||
| node_reference[node.input(j)] -= 1; | |||||
| } | |||||
| } | |||||
| else if (op == "InstanceNormalization") | else if (op == "InstanceNormalization") | ||||
| { | { | ||||
| node_reference[node.input(1)] -= 1; | node_reference[node.input(1)] -= 1; | ||||
| @@ -2251,6 +2351,13 @@ int main(int argc, char** argv) | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| else if (op == "RNN") | |||||
| { | |||||
| for (int j = 1; j < node.input_size(); j++) | |||||
| { | |||||
| node_reference[node.input(j)] -= 1; | |||||
| } | |||||
| } | |||||
| else if (op == "Slice") | else if (op == "Slice") | ||||
| { | { | ||||
| if (node.input_size() >= 2) | if (node.input_size() >= 2) | ||||