| @@ -205,6 +205,7 @@ int main(int argc, char** argv) | |||
| fprintf(pp, "7767517\n"); | |||
| const onnx::GraphProto& graph = model.graph(); | |||
| onnx::GraphProto* mutable_graph = model.mutable_graph(); | |||
| int node_count = graph.node_size(); | |||
| @@ -221,6 +222,8 @@ int main(int argc, char** argv) | |||
| { | |||
| const onnx::TensorProto& initializer = graph.initializer(j); | |||
| // fprintf(stderr, "weight = %s\n", initializer.name().c_str()); | |||
| weights[initializer.name()] = initializer; | |||
| } | |||
| @@ -364,6 +367,86 @@ int main(int argc, char** argv) | |||
| input_node_count++; | |||
| } | |||
| // op chain fusion | |||
| int reduced_node_count = 0; | |||
| for (int i=0; i<node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| // MatMul <= Transpose(weight) - MatMul | |||
| if (node->op_type() == "Transpose") | |||
| { | |||
| // check weight | |||
| if (weights.find(node->input(0)) == weights.end()) | |||
| continue; | |||
| onnx::TensorProto& B = weights[node->input(0)]; | |||
| if (B.dims_size() != 2) | |||
| continue; | |||
| if (node_reference[node->output(0)] != 1) | |||
| continue; | |||
| // perm = (1, 0) | |||
| std::vector<int> perm = get_node_attr_ai(*node, "perm"); | |||
| if (perm.size() != 2) | |||
| continue; | |||
| if (perm[0] != 1 || perm[1] != 0) | |||
| continue; | |||
| if (i+1 >= node_count) | |||
| continue; | |||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1); | |||
| if (node2->op_type() != "MatMul") | |||
| continue; | |||
| // reduce | |||
| node->set_op_type("noop_reducedncnn"); | |||
| node_reference.erase(node_reference.find(node->output(0))); | |||
| blob_names.erase(node->output(0)); | |||
| node2->set_input(1, node->input(0)); | |||
| // permute weight | |||
| { | |||
| const int h = B.dims(0); | |||
| const int w = B.dims(1); | |||
| std::vector<float> permuted_data; | |||
| permuted_data.reserve(h * w); | |||
| const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); | |||
| for (int j=0; j<w; j++) | |||
| { | |||
| for (int k=0; k<h; k++) | |||
| { | |||
| float vb = bptr[ k*w + j ]; | |||
| permuted_data.push_back(vb); | |||
| } | |||
| } | |||
| B.set_dims(0, w); | |||
| B.set_dims(1, h); | |||
| if (B.has_raw_data()) | |||
| { | |||
| B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); | |||
| } | |||
| else | |||
| { | |||
| for (int j=0; j<(int)permuted_data.size(); j++) | |||
| B.set_float_data(j, permuted_data[j]); | |||
| } | |||
| } | |||
| reduced_node_count += 1; | |||
| i += 1; | |||
| } | |||
| } | |||
| // remove node_reference entry with reference equals to one | |||
| int splitncnn_blob_count = 0; | |||
| std::map<std::string, int>::iterator it = node_reference.begin(); | |||
| @@ -381,7 +464,7 @@ int main(int argc, char** argv) | |||
| } | |||
| } | |||
| fprintf(pp, "%lu %lu\n", node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count); | |||
| fprintf(pp, "%lu %lu\n", node_count - reduced_node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count); | |||
| int internal_split = 0; | |||
| @@ -436,6 +519,13 @@ int main(int argc, char** argv) | |||
| const std::string& op = node.op_type(); | |||
| // fprintf(stderr, "op = %s\n", op.c_str()); | |||
| if (op == "noop_reducedncnn") | |||
| { | |||
| continue; | |||
| } | |||
| std::string name = node.name(); | |||
| if (name.empty()) | |||
| { | |||