Browse Source

fuse transpose(weight)+matmul to matmul, fix #620

tags/20190320
nihuini 7 years ago
parent
commit
3b404dcedd
1 changed files with 91 additions and 1 deletions
  1. +91
    -1
      tools/onnx/onnx2ncnn.cpp

+ 91
- 1
tools/onnx/onnx2ncnn.cpp View File

@@ -205,6 +205,7 @@ int main(int argc, char** argv)
fprintf(pp, "7767517\n");

const onnx::GraphProto& graph = model.graph();
onnx::GraphProto* mutable_graph = model.mutable_graph();

int node_count = graph.node_size();

@@ -221,6 +222,8 @@ int main(int argc, char** argv)
{
const onnx::TensorProto& initializer = graph.initializer(j);

// fprintf(stderr, "weight = %s\n", initializer.name().c_str());

weights[initializer.name()] = initializer;
}

@@ -364,6 +367,86 @@ int main(int argc, char** argv)
input_node_count++;
}

// op chain fusion
int reduced_node_count = 0;
for (int i=0; i<node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// MatMul <= Transpose(weight) - MatMul
if (node->op_type() == "Transpose")
{
// check weight
if (weights.find(node->input(0)) == weights.end())
continue;

onnx::TensorProto& B = weights[node->input(0)];
if (B.dims_size() != 2)
continue;

if (node_reference[node->output(0)] != 1)
continue;

// perm = (1, 0)
std::vector<int> perm = get_node_attr_ai(*node, "perm");
if (perm.size() != 2)
continue;
if (perm[0] != 1 || perm[1] != 0)
continue;

if (i+1 >= node_count)
continue;

onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1);

if (node2->op_type() != "MatMul")
continue;

// reduce
node->set_op_type("noop_reducedncnn");

node_reference.erase(node_reference.find(node->output(0)));
blob_names.erase(node->output(0));

node2->set_input(1, node->input(0));

// permute weight
{
const int h = B.dims(0);
const int w = B.dims(1);

std::vector<float> permuted_data;
permuted_data.reserve(h * w);
const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

for (int j=0; j<w; j++)
{
for (int k=0; k<h; k++)
{
float vb = bptr[ k*w + j ];
permuted_data.push_back(vb);
}
}

B.set_dims(0, w);
B.set_dims(1, h);

if (B.has_raw_data())
{
B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
}
else
{
for (int j=0; j<(int)permuted_data.size(); j++)
B.set_float_data(j, permuted_data[j]);
}
}

reduced_node_count += 1;
i += 1;
}
}

// remove node_reference entry with reference equals to one
int splitncnn_blob_count = 0;
std::map<std::string, int>::iterator it = node_reference.begin();
@@ -381,7 +464,7 @@ int main(int argc, char** argv)
}
}

fprintf(pp, "%lu %lu\n", node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count);
fprintf(pp, "%lu %lu\n", node_count - reduced_node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count);

int internal_split = 0;

@@ -436,6 +519,13 @@ int main(int argc, char** argv)

const std::string& op = node.op_type();

// fprintf(stderr, "op = %s\n", op.c_str());

if (op == "noop_reducedncnn")
{
continue;
}

std::string name = node.name();
if (name.empty())
{


Loading…
Cancel
Save