diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index ec4aa1d8a..8cdb73ca3 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -714,7 +714,6 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) { int node_count = mutable_graph->node_size(); @@ -794,6 +793,126 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +{ + int node_count = mutable_graph->node_size(); + for (int i=0; imutable_node(i); + + // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape + if (node->op_type() == "Shape") + { + if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + continue; + + if (i+6 >= node_count) + continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i+2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i+3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i+4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i+5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i+6); + + if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" + || node6->op_type() != "Concat" || node7->op_type() != "Reshape") + continue; + + if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + continue; + +// if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) +// continue; + + if (node_reference.find(node4->output(0)) == node_reference.end() || node_reference[node4->output(0)] != 1) + continue; + + if (node_reference.find(node5->output(0)) == node_reference.end() || node_reference[node5->output(0)] != 1) + continue; + + if (node_reference.find(node6->output(0)) == node_reference.end() || node_reference[node6->output(0)] != 1) + continue; + + if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0) + || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0) + || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0)) + continue; + + // axis = 0 + int gather_axis = get_node_attr_i(*node2, "axis"); + if (gather_axis != 0) + continue; + + // indices = 0 + if (weights.find(node2->input(1)) == weights.end()) + continue; + + std::vector gather_indices = get_tensor_proto_reshape_shape(weights[node2->input(1)]); + if (gather_indices.size() != 1 || gather_indices[0] != 0) + continue; + + // axes = (0) + std::vector unsqueeze_axes = get_node_attr_ai(*node4, "axes"); + if (unsqueeze_axes.size() != 1) + continue; + if (unsqueeze_axes[0] != 0) + continue; + + // axes = (0) + std::vector unsqueeze2_axes = get_node_attr_ai(*node5, "axes"); + if (unsqueeze2_axes.size() != 1) + continue; + if (unsqueeze2_axes[0] != 0) + continue; + + // data = -1 + if (weights.find(node5->input(0)) == weights.end()) + continue; + + std::vector unsqueeze2_data = get_tensor_proto_reshape_shape(weights[node5->input(0)]); + if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) + continue; + + // axis = 0 + int concat_axis = get_node_attr_i(*node6, "axis"); + if (concat_axis != 0) + continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); +// node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + node5->set_op_type("noop_reducedncnn"); + node6->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + + node_reference.erase(node_reference.find(node->output(0))); + node_reference.erase(node_reference.find(node2->output(0))); +// node_reference.erase(node_reference.find(node3->output(0))); + node_reference.erase(node_reference.find(node4->output(0))); + node_reference.erase(node_reference.find(node5->output(0))); + node_reference.erase(node_reference.find(node6->output(0))); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); +// blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + blob_names.erase(node5->output(0)); + blob_names.erase(node6->output(0)); + + node7->set_op_type("Flatten"); + node7->clear_input(); + node7->add_input(node->input(0)); + + reduced_node_count += 5; + i += 5; + } + } +} + int main(int argc, char** argv) { const char* onnxpb = argv[1]; @@ -989,6 +1108,7 @@ int main(int argc, char** argv) fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_unsqueeze_prelu(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_normalize (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); + fuse_flatten (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); // remove node_reference entry with reference equals to one int splitncnn_blob_count = 0;