diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 0f00d3f1d..b2b2777f5 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -205,6 +205,7 @@ int main(int argc, char** argv) fprintf(pp, "7767517\n"); const onnx::GraphProto& graph = model.graph(); + onnx::GraphProto* mutable_graph = model.mutable_graph(); int node_count = graph.node_size(); @@ -221,6 +222,8 @@ int main(int argc, char** argv) { const onnx::TensorProto& initializer = graph.initializer(j); +// fprintf(stderr, "weight = %s\n", initializer.name().c_str()); + weights[initializer.name()] = initializer; } @@ -364,6 +367,86 @@ int main(int argc, char** argv) input_node_count++; } + // op chain fusion + int reduced_node_count = 0; + for (int i=0; imutable_node(i); + + // MatMul <= Transpose(weight) - MatMul + if (node->op_type() == "Transpose") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) + continue; + + onnx::TensorProto& B = weights[node->input(0)]; + if (B.dims_size() != 2) + continue; + + if (node_reference[node->output(0)] != 1) + continue; + + // perm = (1, 0) + std::vector perm = get_node_attr_ai(*node, "perm"); + if (perm.size() != 2) + continue; + if (perm[0] != 1 || perm[1] != 0) + continue; + + if (i+1 >= node_count) + continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1); + + if (node2->op_type() != "MatMul") + continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference.erase(node_reference.find(node->output(0))); + blob_names.erase(node->output(0)); + + node2->set_input(1, node->input(0)); + + // permute weight + { + const int h = B.dims(0); + const int w = B.dims(1); + + std::vector permuted_data; + permuted_data.reserve(h * w); + const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); + + for (int j=0; j::iterator it = node_reference.begin(); @@ -381,7 +464,7 @@ int main(int argc, char** argv) } } - fprintf(pp, "%lu %lu\n", node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count); + fprintf(pp, "%lu %lu\n", node_count - reduced_node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count); int internal_split = 0; @@ -436,6 +519,13 @@ int main(int argc, char** argv) const std::string& op = node.op_type(); +// fprintf(stderr, "op = %s\n", op.c_str()); + + if (op == "noop_reducedncnn") + { + continue; + } + std::string name = node.name(); if (name.empty()) {