Browse Source

onnx fuse flatten

tags/20200106
nihuini 6 years ago
parent
commit
c95e7bbb3c
1 changed files with 121 additions and 1 deletions
  1. +121
    -1
      tools/onnx/onnx2ncnn.cpp

+ 121
- 1
tools/onnx/onnx2ncnn.cpp View File

@@ -714,7 +714,6 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::
}
}


static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights)
{
int node_count = mutable_graph->node_size();
@@ -794,6 +793,126 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string
}
}

static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights)
{
int node_count = mutable_graph->node_size();
for (int i=0; i<node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
if (node->op_type() == "Shape")
{
if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1)
continue;

if (i+6 >= node_count)
continue;

onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i+2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i+3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i+4);
onnx::NodeProto* node6 = mutable_graph->mutable_node(i+5);
onnx::NodeProto* node7 = mutable_graph->mutable_node(i+6);

if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
|| node6->op_type() != "Concat" || node7->op_type() != "Reshape")
continue;

if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1)
continue;

// if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1)
// continue;

if (node_reference.find(node4->output(0)) == node_reference.end() || node_reference[node4->output(0)] != 1)
continue;

if (node_reference.find(node5->output(0)) == node_reference.end() || node_reference[node5->output(0)] != 1)
continue;

if (node_reference.find(node6->output(0)) == node_reference.end() || node_reference[node6->output(0)] != 1)
continue;

if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
|| node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
|| node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
continue;

// axis = 0
int gather_axis = get_node_attr_i(*node2, "axis");
if (gather_axis != 0)
continue;

// indices = 0
if (weights.find(node2->input(1)) == weights.end())
continue;

std::vector<int> gather_indices = get_tensor_proto_reshape_shape(weights[node2->input(1)]);
if (gather_indices.size() != 1 || gather_indices[0] != 0)
continue;

// axes = (0)
std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
if (unsqueeze_axes.size() != 1)
continue;
if (unsqueeze_axes[0] != 0)
continue;

// axes = (0)
std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
if (unsqueeze2_axes.size() != 1)
continue;
if (unsqueeze2_axes[0] != 0)
continue;

// data = -1
if (weights.find(node5->input(0)) == weights.end())
continue;

std::vector<int> unsqueeze2_data = get_tensor_proto_reshape_shape(weights[node5->input(0)]);
if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
continue;

// axis = 0
int concat_axis = get_node_attr_i(*node6, "axis");
if (concat_axis != 0)
continue;

// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
// node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");
node6->set_op_type("noop_reducedncnn");

node_reference[node->input(0)] -= 1;

node_reference.erase(node_reference.find(node->output(0)));
node_reference.erase(node_reference.find(node2->output(0)));
// node_reference.erase(node_reference.find(node3->output(0)));
node_reference.erase(node_reference.find(node4->output(0)));
node_reference.erase(node_reference.find(node5->output(0)));
node_reference.erase(node_reference.find(node6->output(0)));
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
// blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
blob_names.erase(node5->output(0));
blob_names.erase(node6->output(0));

node7->set_op_type("Flatten");
node7->clear_input();
node7->add_input(node->input(0));

reduced_node_count += 5;
i += 5;
}
}
}

int main(int argc, char** argv)
{
const char* onnxpb = argv[1];
@@ -989,6 +1108,7 @@ int main(int argc, char** argv)
fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_unsqueeze_prelu(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_normalize (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);
fuse_flatten (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights);

// remove node_reference entry with reference equals to one
int splitncnn_blob_count = 0;


Loading…
Cancel
Save