Browse Source

fuse onnx gather after shufflechannel into split

tags/20200727
nihuini 5 years ago
parent
commit
bcf982cc9d
1 changed files with 79 additions and 0 deletions
  1. +79
    -0
      tools/onnx/onnx2ncnn.cpp

+ 79
- 0
tools/onnx/onnx2ncnn.cpp View File

@@ -432,6 +432,84 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::s
}
}

static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights)
{
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
if (node->op_type() == "ShuffleChannel")
{
// reverse = 1
int reverse = get_node_attr_i(*node, "reverse");
if (reverse != 1)
continue;

if (i + 2 >= node_count)
continue;

onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
continue;

if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
continue;

// axis = 0
int gather2_axis = get_node_attr_i(*node2, "axis");
if (gather2_axis != 0)
continue;

// indices = 0
if (weights.find(node2->input(1)) == weights.end())
continue;

std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
continue;

// axis = 0
int gather3_axis = get_node_attr_i(*node3, "axis");
if (gather3_axis != 0)
continue;

// indices = 1
if (weights.find(node3->input(1)) == weights.end())
continue;

std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
continue;

// reduce
node2->set_op_type("noop_reducedncnn");

node_reference[node->output(0)] -= 1;

node_reference.erase(node_reference.find(node2->output(0)));
blob_names.erase(node2->output(0));

node3->set_op_type("Split");
node3->clear_input();
node3->add_input(node->output(0));
node3->add_output(node3->output(0));
node3->set_output(0, node2->output(0));

node3->clear_attribute();
onnx::AttributeProto* attr_axis = node3->add_attribute();
attr_axis->set_name("axis");
attr_axis->set_i(1);

reduced_node_count += 1;
i += 1;
}
}
}

static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights)
{
int node_count = mutable_graph->node_size();
@@ -1383,6 +1461,7 @@ int main(int argc, char** argv)
std::vector<std::string> reduced_binaryop_weights;
fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_shufflechannel_split(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);


Loading…
Cancel
Save