From dbef56df477cb8a80ccafa183144c6123f577f30 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 30 Apr 2019 11:57:10 +0800 Subject: [PATCH] write graph --- tools/pytorch/pytorch2ncnn.cpp | 184 +++++++++++++++++++++++++++++++-- 1 file changed, 178 insertions(+), 6 deletions(-) diff --git a/tools/pytorch/pytorch2ncnn.cpp b/tools/pytorch/pytorch2ncnn.cpp index 1d51d493e..8c1c57f16 100644 --- a/tools/pytorch/pytorch2ncnn.cpp +++ b/tools/pytorch/pytorch2ncnn.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -439,6 +440,8 @@ class PyTorchNode { public: std::string op; + std::string name; + std::vector inputs; std::vector outputs; std::vector args; }; @@ -563,6 +566,8 @@ static int read_code(const std::string& code, std::vector& nodes) { nodes.clear(); + int internal_unknown = 0; + // read code line bool forward_input = false; @@ -595,8 +600,24 @@ static int read_code(const std::string& code, std::vector& nodes) if (nscan != 1) continue; - fprintf(stderr, "netinput = %s\n", netinput); +// fprintf(stderr, "netinput = %s\n", netinput); forward_input = false; + + PyTorchNode n; + n.op = "Input"; + n.outputs.push_back(netinput); + + { + // assign default unknown name + char unknownname[256]; + sprintf(unknownname, "unknownncnn_%d", internal_unknown); + + n.name = unknownname; + + internal_unknown++; + } + + nodes.push_back(n); } else if (strstr(line, " = torch.")) { @@ -627,6 +648,17 @@ static int read_code(const std::string& code, std::vector& nodes) n.op = op; n.outputs = parse_op_output_list(outputs); n.args = parse_op_arg_list(op_args.c_str()); + + { + // assign default unknown name + char unknownname[256]; + sprintf(unknownname, "unknownncnn_%d", internal_unknown); + + n.name = unknownname; + + internal_unknown++; + } + nodes.push_back(n); } else if (strstr(line, "return torch.")) @@ -653,8 +685,19 @@ static int read_code(const std::string& code, std::vector& nodes) PyTorchNode n; n.op = op; - n.outputs = parse_op_output_list("ncnnoutput_0"); + n.outputs.push_back("ncnnoutput_0"); n.args = parse_op_arg_list(op_args.c_str()); + + { + // assign default unknown name + char unknownname[256]; + sprintf(unknownname, "unknownncnn_%d", internal_unknown); + + n.name = unknownname; + + internal_unknown++; + } + nodes.push_back(n); } } @@ -665,6 +708,8 @@ static int read_code(const std::string& code, std::vector& nodes) int main(int argc, char** argv) { const char* ptpath = "model_base.pt"; + const char* ncnn_prototxt = "ncnn.param"; + const char* ncnn_modelbin = "ncnn.bin"; std::string code; std::string model_json; @@ -682,26 +727,153 @@ int main(int argc, char** argv) std::vector nodes; read_code(code, nodes); + FILE* pp = fopen(ncnn_prototxt, "wb"); + FILE* bp = fopen(ncnn_modelbin, "wb"); + + // magic + fprintf(pp, "7767517\n"); + const int node_count = nodes.size(); fprintf(stderr, "node_count = %d\n", node_count); + // node reference + std::map node_reference; + + // weight node + std::vector weight_nodes; + + // global definition line + // [layer count] [blob count] + std::set blob_names; + for (int i=0; i op_arg_list = n.args; + std::vector input_list; + std::vector arg_list; + for (int i=0; i<(int)op_arg_list.size(); i++) + { + const std::string& arg = op_arg_list[i]; + if (blob_names.find(arg) == blob_names.end()) + { + arg_list.push_back(arg); + } + else + { + input_list.push_back(arg); + + if (node_reference.find(arg) == node_reference.end()) + { + node_reference[arg] = 1; + } + else + { + node_reference[arg] = node_reference[arg] + 1; + } + } + } + + n.inputs = input_list; + n.args = arg_list; + } + + // remove node_reference entry with reference equals to one + int splitncnn_blob_count = 0; + std::map::iterator it = node_reference.begin(); + while (it != node_reference.end()) + { + if (it->second == 1) + { + node_reference.erase(it++); + } + else + { + splitncnn_blob_count += it->second; +// fprintf(stderr, "%s %d\n", it->first.c_str(), it->second); + ++it; + } + } + + fprintf(pp, "%lu %lu\n", node_count + node_reference.size(), blob_names.size() + splitncnn_blob_count); + + int internal_split = 0; + for (int i=0; i 1) + { + char splitname[256]; + sprintf(splitname, "splitncnn_%d", internal_split); + fprintf(pp, "%-24s %-24s %d %d", "Split", splitname, 1, refcount); + + fprintf(pp, " %s", output_name.c_str()); + + for (int k=0; k