| @@ -714,7 +714,6 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std:: | |||
| } | |||
| } | |||
| static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | |||
| { | |||
| int node_count = mutable_graph->node_size(); | |||
| @@ -794,6 +793,126 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string | |||
| } | |||
| } | |||
| static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | |||
| { | |||
| int node_count = mutable_graph->node_size(); | |||
| for (int i=0; i<node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape | |||
| if (node->op_type() == "Shape") | |||
| { | |||
| if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) | |||
| continue; | |||
| if (i+6 >= node_count) | |||
| continue; | |||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1); | |||
| onnx::NodeProto* node3 = mutable_graph->mutable_node(i+2); | |||
| onnx::NodeProto* node4 = mutable_graph->mutable_node(i+3); | |||
| onnx::NodeProto* node5 = mutable_graph->mutable_node(i+4); | |||
| onnx::NodeProto* node6 = mutable_graph->mutable_node(i+5); | |||
| onnx::NodeProto* node7 = mutable_graph->mutable_node(i+6); | |||
| if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" | |||
| || node6->op_type() != "Concat" || node7->op_type() != "Reshape") | |||
| continue; | |||
| if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) | |||
| continue; | |||
| // if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) | |||
| // continue; | |||
| if (node_reference.find(node4->output(0)) == node_reference.end() || node_reference[node4->output(0)] != 1) | |||
| continue; | |||
| if (node_reference.find(node5->output(0)) == node_reference.end() || node_reference[node5->output(0)] != 1) | |||
| continue; | |||
| if (node_reference.find(node6->output(0)) == node_reference.end() || node_reference[node6->output(0)] != 1) | |||
| continue; | |||
| if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0) | |||
| || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0) | |||
| || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0)) | |||
| continue; | |||
| // axis = 0 | |||
| int gather_axis = get_node_attr_i(*node2, "axis"); | |||
| if (gather_axis != 0) | |||
| continue; | |||
| // indices = 0 | |||
| if (weights.find(node2->input(1)) == weights.end()) | |||
| continue; | |||
| std::vector<int> gather_indices = get_tensor_proto_reshape_shape(weights[node2->input(1)]); | |||
| if (gather_indices.size() != 1 || gather_indices[0] != 0) | |||
| continue; | |||
| // axes = (0) | |||
| std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes"); | |||
| if (unsqueeze_axes.size() != 1) | |||
| continue; | |||
| if (unsqueeze_axes[0] != 0) | |||
| continue; | |||
| // axes = (0) | |||
| std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes"); | |||
| if (unsqueeze2_axes.size() != 1) | |||
| continue; | |||
| if (unsqueeze2_axes[0] != 0) | |||
| continue; | |||
| // data = -1 | |||
| if (weights.find(node5->input(0)) == weights.end()) | |||
| continue; | |||
| std::vector<int> unsqueeze2_data = get_tensor_proto_reshape_shape(weights[node5->input(0)]); | |||
| if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) | |||
| continue; | |||
| // axis = 0 | |||
| int concat_axis = get_node_attr_i(*node6, "axis"); | |||
| if (concat_axis != 0) | |||
| continue; | |||
| // reduce | |||
| node->set_op_type("noop_reducedncnn"); | |||
| node2->set_op_type("noop_reducedncnn"); | |||
| // node3->set_op_type("noop_reducedncnn"); | |||
| node4->set_op_type("noop_reducedncnn"); | |||
| node5->set_op_type("noop_reducedncnn"); | |||
| node6->set_op_type("noop_reducedncnn"); | |||
| node_reference[node->input(0)] -= 1; | |||
| node_reference.erase(node_reference.find(node->output(0))); | |||
| node_reference.erase(node_reference.find(node2->output(0))); | |||
| // node_reference.erase(node_reference.find(node3->output(0))); | |||
| node_reference.erase(node_reference.find(node4->output(0))); | |||
| node_reference.erase(node_reference.find(node5->output(0))); | |||
| node_reference.erase(node_reference.find(node6->output(0))); | |||
| blob_names.erase(node->output(0)); | |||
| blob_names.erase(node2->output(0)); | |||
| // blob_names.erase(node3->output(0)); | |||
| blob_names.erase(node4->output(0)); | |||
| blob_names.erase(node5->output(0)); | |||
| blob_names.erase(node6->output(0)); | |||
| node7->set_op_type("Flatten"); | |||
| node7->clear_input(); | |||
| node7->add_input(node->input(0)); | |||
| reduced_node_count += 5; | |||
| i += 5; | |||
| } | |||
| } | |||
| } | |||
| int main(int argc, char** argv) | |||
| { | |||
| const char* onnxpb = argv[1]; | |||
| @@ -989,6 +1108,7 @@ int main(int argc, char** argv) | |||
| fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_unsqueeze_prelu(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_normalize (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_flatten (mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| // remove node_reference entry with reference equals to one | |||
| int splitncnn_blob_count = 0; | |||