Browse Source

write graph

tags/20190611
nihuini 7 years ago
parent
commit
dbef56df47
1 changed files with 178 additions and 6 deletions
  1. +178
    -6
      tools/pytorch/pytorch2ncnn.cpp

+ 178
- 6
tools/pytorch/pytorch2ncnn.cpp View File

@@ -17,6 +17,7 @@
#include <string.h>

#include <map>
#include <set>
#include <string>
#include <vector>

@@ -439,6 +440,8 @@ class PyTorchNode
{
public:
std::string op;
std::string name;
std::vector<std::string> inputs;
std::vector<std::string> outputs;
std::vector<std::string> args;
};
@@ -563,6 +566,8 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes)
{
nodes.clear();

int internal_unknown = 0;

// read code line
bool forward_input = false;

@@ -595,8 +600,24 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes)
if (nscan != 1)
continue;

fprintf(stderr, "netinput = %s\n", netinput);
// fprintf(stderr, "netinput = %s\n", netinput);
forward_input = false;

PyTorchNode n;
n.op = "Input";
n.outputs.push_back(netinput);

{
// assign default unknown name
char unknownname[256];
sprintf(unknownname, "unknownncnn_%d", internal_unknown);

n.name = unknownname;

internal_unknown++;
}

nodes.push_back(n);
}
else if (strstr(line, " = torch."))
{
@@ -627,6 +648,17 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes)
n.op = op;
n.outputs = parse_op_output_list(outputs);
n.args = parse_op_arg_list(op_args.c_str());

{
// assign default unknown name
char unknownname[256];
sprintf(unknownname, "unknownncnn_%d", internal_unknown);

n.name = unknownname;

internal_unknown++;
}

nodes.push_back(n);
}
else if (strstr(line, "return torch."))
@@ -653,8 +685,19 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes)

PyTorchNode n;
n.op = op;
n.outputs = parse_op_output_list("ncnnoutput_0");
n.outputs.push_back("ncnnoutput_0");
n.args = parse_op_arg_list(op_args.c_str());

{
// assign default unknown name
char unknownname[256];
sprintf(unknownname, "unknownncnn_%d", internal_unknown);

n.name = unknownname;

internal_unknown++;
}

nodes.push_back(n);
}
}
@@ -665,6 +708,8 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes)
int main(int argc, char** argv)
{
const char* ptpath = "model_base.pt";
const char* ncnn_prototxt = "ncnn.param";
const char* ncnn_modelbin = "ncnn.bin";

std::string code;
std::string model_json;
@@ -682,26 +727,153 @@ int main(int argc, char** argv)
std::vector<PyTorchNode> nodes;
read_code(code, nodes);

FILE* pp = fopen(ncnn_prototxt, "wb");
FILE* bp = fopen(ncnn_modelbin, "wb");

// magic
fprintf(pp, "7767517\n");

const int node_count = nodes.size();

fprintf(stderr, "node_count = %d\n", node_count);

// node reference
std::map<std::string, int> node_reference;

// weight node
std::vector<int> weight_nodes;

// global definition line
// [layer count] [blob count]
std::set<std::string> blob_names;
for (int i=0; i<node_count; i++)
{
PyTorchNode& n = nodes[i];

for (int j=0; j<(int)n.outputs.size(); j++)
{
blob_names.insert(n.outputs[j]);
}

// distinguish weights and inputs
std::vector<std::string> op_arg_list = n.args;
std::vector<std::string> input_list;
std::vector<std::string> arg_list;
for (int i=0; i<(int)op_arg_list.size(); i++)
{
const std::string& arg = op_arg_list[i];
if (blob_names.find(arg) == blob_names.end())
{
arg_list.push_back(arg);
}
else
{
input_list.push_back(arg);

if (node_reference.find(arg) == node_reference.end())
{
node_reference[arg] = 1;
}
else
{
node_reference[arg] = node_reference[arg] + 1;
}
}
}

n.inputs = input_list;
n.args = arg_list;
}

// remove node_reference entry with reference equals to one
int splitncnn_blob_count = 0;
std::map<std::string, int>::iterator it = node_reference.begin();
while (it != node_reference.end())
{
if (it->second == 1)
{
node_reference.erase(it++);
}
else
{
splitncnn_blob_count += it->second;
// fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
++it;
}
}

fprintf(pp, "%lu %lu\n", node_count + node_reference.size(), blob_names.size() + splitncnn_blob_count);

int internal_split = 0;

for (int i=0; i<node_count; i++)
{
const PyTorchNode& n = nodes[i];

fprintf(stderr, "op = %s\n", n.op.c_str());
fprintf(pp, "%-24s", n.op.c_str());

fprintf(pp, " %-24s %d %d", n.name.c_str(), (int)n.inputs.size(), (int)n.outputs.size());

for (int j=0; j<(int)n.inputs.size(); j++)
{
std::string input_name = n.inputs[j];

if (node_reference.find(input_name) != node_reference.end())
{
int refidx = node_reference[input_name] - 1;
node_reference[input_name] = refidx;

char splitsuffix[256];
sprintf(splitsuffix, "_splitncnn_%d", refidx);
input_name = input_name + splitsuffix;
}

fprintf(pp, " %s", input_name.c_str());
}

for (int j=0; j<(int)n.outputs.size(); j++)
{
fprintf(stderr, "output = %s\n", n.outputs[j].c_str());
fprintf(pp, " %s", n.outputs[j].c_str());
}

// TODO op specific params
{
for (int j=0; j<(int)n.args.size(); j++)
{
fprintf(pp, " %s", n.args[j].c_str());
}
}

for (int j=0; j<(int)n.args.size(); j++)
fprintf(pp, "\n");

for (int j=0; j<(int)n.outputs.size(); j++)
{
fprintf(stderr, "arg = %s\n", n.args[j].c_str());
const std::string& output_name = n.outputs[j];
if (node_reference.find(output_name) != node_reference.end())
{
int refcount = node_reference[output_name];
if (refcount > 1)
{
char splitname[256];
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-24s %-24s %d %d", "Split", splitname, 1, refcount);

fprintf(pp, " %s", output_name.c_str());

for (int k=0; k<refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
}
fprintf(pp, "\n");

internal_split++;
}
}
}
}

fclose(pp);
fclose(bp);

return 0;
}

Loading…
Cancel
Save