From bcf982cc9d9246e7b83a75ef97dbcfee391e4fa3 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 22 Jun 2020 16:19:16 +0800 Subject: [PATCH] fuse onnx gather after shufflechannel into split --- tools/onnx/onnx2ncnn.cpp | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 0a0f654ac..899b66b35 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -432,6 +432,84 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) + if (node->op_type() == "ShuffleChannel") + { + // reverse = 1 + int reverse = get_node_attr_i(*node, "reverse"); + if (reverse != 1) + continue; + + if (i + 2 >= node_count) + continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + + if (node2->op_type() != "Gather" || node3->op_type() != "Gather") + continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) + continue; + + // axis = 0 + int gather2_axis = get_node_attr_i(*node2, "axis"); + if (gather2_axis != 0) + continue; + + // indices = 0 + if (weights.find(node2->input(1)) == weights.end()) + continue; + + std::vector gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); + if (gather2_indices.size() != 1 || gather2_indices[0] != 0) + continue; + + // axis = 0 + int gather3_axis = get_node_attr_i(*node3, "axis"); + if (gather3_axis != 0) + continue; + + // indices = 1 + if (weights.find(node3->input(1)) == weights.end()) + continue; + + std::vector gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); + if (gather3_indices.size() != 1 || gather3_indices[0] != 1) + continue; + + // reduce + node2->set_op_type("noop_reducedncnn"); + + node_reference[node->output(0)] -= 1; + + node_reference.erase(node_reference.find(node2->output(0))); + blob_names.erase(node2->output(0)); + + node3->set_op_type("Split"); + node3->clear_input(); + node3->add_input(node->output(0)); + node3->add_output(node3->output(0)); + node3->set_output(0, node2->output(0)); + + node3->clear_attribute(); + onnx::AttributeProto* attr_axis = node3->add_attribute(); + attr_axis->set_name("axis"); + attr_axis->set_i(1); + + reduced_node_count += 1; + i += 1; + } + } +} + static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) { int node_count = mutable_graph->node_size(); @@ -1383,6 +1461,7 @@ int main(int argc, char** argv) std::vector reduced_binaryop_weights; fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); + fuse_shufflechannel_split(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);